mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
Fix CUDA softmax by subtracting max value before exp (#2665)
This commit is contained in:
parent
deb7dfca4b
commit
800c9635b4
37
ggml-cuda.cu
37
ggml-cuda.cu
@ -3979,24 +3979,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
|||||||
|
|
||||||
// the CUDA soft max implementation differs from the CPU implementation
|
// the CUDA soft max implementation differs from the CPU implementation
|
||||||
// instead of doubles floats are used
|
// instead of doubles floats are used
|
||||||
// values are also not normalized to the maximum value by subtracting it in the exponential function
|
|
||||||
// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
|
|
||||||
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
|
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
|
||||||
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
const int row = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
const int block_size = blockDim.y;
|
const int block_size = blockDim.y;
|
||||||
const int tid = threadIdx.y;
|
const int tid = threadIdx.y;
|
||||||
|
|
||||||
float tmp = 0.0;
|
float max_val = -INFINITY;
|
||||||
|
|
||||||
for (int block_start = 0; block_start < ncols; block_start += block_size) {
|
|
||||||
const int col = block_start + tid;
|
|
||||||
|
|
||||||
if (col >= ncols) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
const int i = row*ncols + col;
|
const int i = row*ncols + col;
|
||||||
const float val = expf(x[i]);
|
max_val = max(max_val, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the max value in the block
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
|
max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
|
||||||
|
}
|
||||||
|
|
||||||
|
float tmp = 0.f;
|
||||||
|
|
||||||
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
|
const int i = row*ncols + col;
|
||||||
|
const float val = expf(x[i] - max_val);
|
||||||
tmp += val;
|
tmp += val;
|
||||||
dst[i] = val;
|
dst[i] = val;
|
||||||
}
|
}
|
||||||
@ -4007,15 +4012,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
|
|||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int block_start = 0; block_start < ncols; block_start += block_size) {
|
const float inv_tmp = 1.f / tmp;
|
||||||
const int col = block_start + tid;
|
|
||||||
|
|
||||||
if (col >= ncols) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
const int i = row*ncols + col;
|
const int i = row*ncols + col;
|
||||||
dst[i] /= tmp;
|
dst[i] *= inv_tmp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user