cuda : fix warp reduction initialization of shared mem

parent 6b86bcffac
commit 68e02c0d58

ggml-cuda.cu: 23 changed lines
@@ -4733,6 +4733,11 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
 
     const int block_size = blockDim.x;
 
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
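This first hunk hoists warp_id, lane_id and the shared buffer buf out of the two reduction blocks and up to kernel scope, so the max and sum reductions further down can initialize and share them. The kernel relies on the warp-level helpers warp_reduce_max and warp_reduce_sum, which are defined elsewhere in ggml-cuda.cu. As a minimal, non-authoritative sketch, a shuffle-based butterfly reduction of this shape (assuming WARP_SIZE == 32) is the usual way such helpers are written:

// Sketch only: the real helpers live elsewhere in ggml-cuda.cu.
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // butterfly exchange: every lane ends up with the warp-wide max
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

After such a reduction every lane of a warp holds the warp-wide result, which is why the block-wide stage below only needs one shared-memory slot per warp.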
@@ -4744,13 +4749,16 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     // find the max value in the block
     max_val = warp_reduce_max(max_val);
     if (block_size > WARP_SIZE) {
-        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
         if (lane_id == 0) {
             buf[warp_id] = max_val;
         }
         __syncthreads();
+
         max_val = buf[lane_id];
         max_val = warp_reduce_max(max_val);
     }
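The bug this hunk fixes: buf has CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE slots, but only one slot per warp actually resident in the block is ever written. When the kernel is launched with fewer warps than slots, max_val = buf[lane_id] reads shared memory that was never initialized. The fix lets warp 0 fill every slot with -INFINITY, the identity element for max, before the per-warp partial maxima are stored. A minimal, self-contained sketch of the same two-stage pattern (not part of the patch; it assumes WARP_SIZE == 32 and blockDim.x <= 1024 so that 32 shared slots are enough):

// Sketch only: block-wide max reduction with the shared buffer initialized
// to the identity value, mirroring the fixed soft_max_f32 code above.
static __device__ float block_reduce_max(float x) {
    __shared__ float buf[32];            // one partial result per possible warp
    const int warp_id = threadIdx.x / 32;
    const int lane_id = threadIdx.x % 32;

    x = warp_reduce_max(x);              // stage 1: reduce within each warp
    if (blockDim.x <= 32) {
        return x;                        // a single warp is already done
    }

    if (warp_id == 0) {
        buf[lane_id] = -INFINITY;        // clear ALL slots, including those of
    }                                    // warps that are not present
    __syncthreads();

    if (lane_id == 0) {
        buf[warp_id] = x;                // store one partial max per warp
    }
    __syncthreads();

    x = buf[lane_id];                    // every lane picks up one partial
    return warp_reduce_max(x);           // stage 2: reduce the partials
}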
@@ -4768,13 +4776,16 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     // find the sum of exps in the block
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
-        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
         if (lane_id == 0) {
             buf[warp_id] = tmp;
         }
         __syncthreads();
+
         tmp = buf[lane_id];
         tmp = warp_reduce_sum(tmp);
     }
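The second hunk applies the same fix to the sum-of-exps reduction; only the combine operation and the identity value change (0.f instead of -INFINITY). A compact sketch of the sum variant, under the same assumptions as the max sketch above:

// Sketch only: sum variant of the same two-stage block reduction.
static __device__ float block_reduce_sum(float x) {
    __shared__ float buf[32];
    const int warp_id = threadIdx.x / 32;
    const int lane_id = threadIdx.x % 32;

    x = warp_reduce_sum(x);              // stage 1: per-warp partial sums
    if (blockDim.x <= 32) {
        return x;
    }

    if (warp_id == 0) {
        buf[lane_id] = 0.0f;             // identity for addition
    }
    __syncthreads();

    if (lane_id == 0) {
        buf[warp_id] = x;
    }
    __syncthreads();

    return warp_reduce_sum(buf[lane_id]); // stage 2: reduce the partials
}

In both reductions the extra __syncthreads() after the initialization guarantees that no warp stores its partial result before warp 0 has finished clearing the buffer.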