diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 98343d208..9019a849f 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4733,6 +4733,11 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
 
     const int block_size = blockDim.x;
 
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -4744,13 +4749,16 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     // find the max value in the block
     max_val = warp_reduce_max(max_val);
     if (block_size > WARP_SIZE) {
-        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
         if (lane_id == 0) {
             buf[warp_id] = max_val;
         }
         __syncthreads();
+
         max_val = buf[lane_id];
         max_val = warp_reduce_max(max_val);
     }
@@ -4768,13 +4776,16 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     // find the sum of exps in the block
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
-        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
         if (lane_id == 0) {
             buf[warp_id] = tmp;
         }
         __syncthreads();
+
         tmp = buf[lane_id];
         tmp = warp_reduce_sum(tmp);
     }
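
The patch hoists warp_id, lane_id, and the shared buffer to the top of soft_max_f32 so both block-level reductions reuse one allocation, and in each reduction has warp 0 initialize every buffer slot to the reduction's neutral element (-INFINITY for the max pass, 0.f for the sum-of-exps pass) behind an extra __syncthreads(). That way the later buf[lane_id] read, which touches all slots, never sees uninitialized shared memory when the block has fewer warps than the buffer has slots. The standalone sketch below shows the same two-stage warp-shuffle plus shared-buffer pattern for the max reduction only; it is not the ggml kernel itself, and the kernel name block_max_f32, the MAX_BLOCK_SIZE constant (standing in for CUDA_SOFT_MAX_BLOCK_SIZE), the reimplemented warp_reduce_max, and the host driver are assumptions for illustration.

// Sketch: block-wide max via warp shuffles and one shared slot per warp.
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>

#define WARP_SIZE      32
#define MAX_BLOCK_SIZE 1024   // stands in for CUDA_SOFT_MAX_BLOCK_SIZE

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, WARP_SIZE));
    }
    return x;
}

// Computes the max of x[0..n) with a single block; blockDim.x must be a
// multiple of WARP_SIZE and no larger than MAX_BLOCK_SIZE.
static __global__ void block_max_f32(const float * x, float * dst, const int n) {
    const int tid     = threadIdx.x;
    const int warp_id = tid / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;

    // One slot per warp, sized for the largest supported block. If the block
    // has fewer warps than slots, the unused slots must be initialized below,
    // otherwise the final buf[lane_id] read picks up garbage -- the issue the
    // patch addresses.
    __shared__ float buf[MAX_BLOCK_SIZE/WARP_SIZE];

    float max_val = -INFINITY;
    for (int i = tid; i < n; i += blockDim.x) {
        max_val = fmaxf(max_val, x[i]);
    }

    // stage 1: reduce within each warp
    max_val = warp_reduce_max(max_val);

    // stage 2: combine the per-warp results through shared memory
    if (warp_id == 0) {
        buf[lane_id] = -INFINITY;      // neutral element for max
    }
    __syncthreads();

    if (lane_id == 0) {
        buf[warp_id] = max_val;        // one write per warp
    }
    __syncthreads();

    max_val = buf[lane_id];            // every warp reads all slots
    max_val = warp_reduce_max(max_val);

    if (tid == 0) {
        dst[0] = max_val;
    }
}

int main() {
    const int n = 1000;
    float h[n];
    for (int i = 0; i < n; ++i) {
        h[i] = (float) (i % 97);       // max is 96
    }

    float * d_x   = nullptr;
    float * d_dst = nullptr;
    cudaMalloc(&d_x,   n*sizeof(float));
    cudaMalloc(&d_dst, sizeof(float));
    cudaMemcpy(d_x, h, n*sizeof(float), cudaMemcpyHostToDevice);

    block_max_f32<<<1, 256>>>(d_x, d_dst, n);

    float result = 0.0f;
    cudaMemcpy(&result, d_dst, sizeof(float), cudaMemcpyDeviceToHost);
    printf("block max = %f\n", result); // expect 96.000000

    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}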