cuda : fix warp reduction initialization of shared mem

This commit is contained in:
Georgi Gerganov 2023-11-30 21:39:48 +02:00
parent 6b86bcffac
commit 68e02c0d58
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -4733,6 +4733,11 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
const int block_size = blockDim.x;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
float max_val = -INFINITY;
for (int col = tid; col < ncols; col += block_size) {
@ -4744,13 +4749,16 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (warp_id == 0) {
buf[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf[warp_id] = max_val;
}
__syncthreads();
max_val = buf[lane_id];
max_val = warp_reduce_max(max_val);
}
@ -4768,13 +4776,16 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (warp_id == 0) {
buf[lane_id] = 0.f;
}
__syncthreads();
if (lane_id == 0) {
buf[warp_id] = tmp;
}
__syncthreads();
tmp = buf[lane_id];
tmp = warp_reduce_sum(tmp);
}