diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 53a478aa9..080193cbd 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -502,6 +502,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -578,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -625,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4718,7 +4726,6 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// TODO: maybe can be improved with some warp-based primitives
 static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
@@ -4726,24 +4733,26 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
 
     const int block_size = blockDim.x;
 
-    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
-
-    buf[tid] = -INFINITY;
+    float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
-    __syncthreads();
-
     // find the max value in the block
-    for (int i = block_size/2; i > 0; i >>= 1) {
-        if (tid < i) {
-            buf[tid] = max(buf[tid], buf[tid + i]);
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
         }
         __syncthreads();
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
@@ -4751,26 +4760,26 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
         dst[ix] = val;
     }
 
-    __syncthreads();
-
-    buf[tid] = tmp;
-
-    __syncthreads();
-
-    // sum up partial sums
-    for (int i = block_size/2; i > 0; i >>= 1) {
-        if (tid < i) {
-            buf[tid] += buf[tid + i];
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
    }
 
-    const float inv_tmp = 1.f / buf[0];
+    const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int i = rowx*ncols + col;
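
Note on the pattern: the diff replaces the shared-memory tree reduction (log2(block_size) rounds, each ending in a `__syncthreads()`) with a two-stage reduction. Stage one, `warp_reduce_max`/`warp_reduce_sum`, combines 32 values per warp with `__shfl_xor_sync`, using no shared memory and no barrier; stage two folds the per-warp partials through one small shared buffer. Below is a minimal standalone sketch of that block-level pattern, not code from ggml-cuda.cu: the names `block_reduce_max`, `row_max`, and `MAX_BLOCK_SIZE` are illustrative. It assumes `blockDim.x` is a multiple of `WARP_SIZE` (32), and unlike the diff it pads unwritten `buf` slots with `-INFINITY`, since `buf[lane_id]` is read by all 32 lanes but written only by the first `blockDim.x/WARP_SIZE` of them.

```cuda
// Minimal sketch (not from ggml-cuda.cu): the two-stage block reduction
// pattern applied inside soft_max_f32, extracted into a standalone program.
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>
#include <vector>

#define WARP_SIZE      32
#define MAX_BLOCK_SIZE 1024

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

// Stage 1: each warp reduces its 32 values with shuffles (no shared memory,
// no barrier). Stage 2: lane 0 of every warp publishes its partial max to
// shared memory, then one more warp-wide shuffle reduce combines the per-warp
// partials. Assumes blockDim.x is a multiple of WARP_SIZE; lanes beyond the
// warp count read -INFINITY (the identity for max), not unwritten memory.
static __device__ float block_reduce_max(float x) {
    x = warp_reduce_max(x);
    if (blockDim.x > WARP_SIZE) {
        __shared__ float buf[MAX_BLOCK_SIZE/WARP_SIZE];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        const int n_warps = blockDim.x  / WARP_SIZE;
        if (lane_id == 0) {
            buf[warp_id] = x;
        }
        __syncthreads();
        x = lane_id < n_warps ? buf[lane_id] : -INFINITY;
        x = warp_reduce_max(x);
    }
    return x; // every thread of the block now holds the block-wide max
}

// One block per row, threads striding over columns, as in soft_max_f32.
static __global__ void row_max(const float * x, float * dst, const int ncols) {
    float v = -INFINITY;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        v = fmaxf(v, x[blockIdx.x*ncols + col]);
    }
    v = block_reduce_max(v);
    if (threadIdx.x == 0) {
        dst[blockIdx.x] = v;
    }
}

int main() {
    const int nrows = 4;
    const int ncols = 1000;

    std::vector<float> h(nrows*ncols);
    for (int i = 0; i < nrows*ncols; ++i) {
        h[i] = sinf(0.01f*i);
    }

    float * d_x;
    float * d_max;
    cudaMalloc(&d_x,   nrows*ncols*sizeof(float));
    cudaMalloc(&d_max, nrows*sizeof(float));
    cudaMemcpy(d_x, h.data(), nrows*ncols*sizeof(float), cudaMemcpyHostToDevice);

    row_max<<<nrows, 128>>>(d_x, d_max, ncols); // 128 threads = 4 warps

    std::vector<float> out(nrows);
    cudaMemcpy(out.data(), d_max, nrows*sizeof(float), cudaMemcpyDeviceToHost);
    for (int row = 0; row < nrows; ++row) {
        printf("row %d: max = %f\n", row, out[row]);
    }

    cudaFree(d_x);
    cudaFree(d_max);
    return 0;
}
```

Launching `row_max` with 128 threads exercises the cross-warp stage; launching with 32 skips it entirely, which is why the diff guards the shared-memory stage with `block_size > WARP_SIZE`. The same two-stage shape works for the sum reduction by swapping `fmaxf` for `+=` and `-INFINITY` for `0.0f`.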