cuda : fix warp reduction initialization of shared memory

This commit is contained in:
Georgi Gerganov 2023-11-30 21:39:48 +02:00
parent 6b86bcffac
commit 68e02c0d58
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -4733,6 +4733,11 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
const int block_size = blockDim.x; const int block_size = blockDim.x;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
float max_val = -INFINITY; float max_val = -INFINITY;
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
@@ -4744,13 +4749,16 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
// find the max value in the block // find the max value in the block
max_val = warp_reduce_max(max_val); max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) { if (block_size > WARP_SIZE) {
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE]; if (warp_id == 0) {
int warp_id = threadIdx.x / WARP_SIZE; buf[lane_id] = -INFINITY;
int lane_id = threadIdx.x % WARP_SIZE; }
__syncthreads();
if (lane_id == 0) { if (lane_id == 0) {
buf[warp_id] = max_val; buf[warp_id] = max_val;
} }
__syncthreads(); __syncthreads();
max_val = buf[lane_id]; max_val = buf[lane_id];
max_val = warp_reduce_max(max_val); max_val = warp_reduce_max(max_val);
} }
@@ -4768,13 +4776,16 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
// find the sum of exps in the block // find the sum of exps in the block
tmp = warp_reduce_sum(tmp); tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) { if (block_size > WARP_SIZE) {
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE]; if (warp_id == 0) {
int warp_id = threadIdx.x / WARP_SIZE; buf[lane_id] = 0.f;
int lane_id = threadIdx.x % WARP_SIZE; }
__syncthreads();
if (lane_id == 0) { if (lane_id == 0) {
buf[warp_id] = tmp; buf[warp_id] = tmp;
} }
__syncthreads(); __syncthreads();
tmp = buf[lane_id]; tmp = buf[lane_id];
tmp = warp_reduce_sum(tmp); tmp = warp_reduce_sum(tmp);
} }