cuda : do warp-based block reduce

Georgi Gerganov 2023-11-30 20:36:08 +02:00
parent c7c8dabcf7
commit 62532c05aa


@@ -502,6 +502,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
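For reference, the shuffle loop in these helpers is a butterfly reduction: each __shfl_xor_sync step exchanges registers between lanes whose indices differ in exactly one bit (XOR masks 16, 8, 4, 2, 1), so after log2(32) = 5 steps every lane of the warp holds the full result, with no shared memory and no __syncthreads. A minimal standalone sketch of the behavior; the test kernel and host harness below are illustrative, not part of this commit:

#include <cstdio>

// warp_reduce_sum as added above: XOR-butterfly over lanes 0..31
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32); // pair lane i with lane i^mask
    }
    return x; // every lane now holds the sum of all 32 lane inputs
}

// hypothetical test: one warp sums 0 + 1 + ... + 31 = 496
static __global__ void test_warp_reduce_sum(float * out) {
    const float sum = warp_reduce_sum((float) threadIdx.x);
    if (threadIdx.x == 0) {
        *out = sum;
    }
}

int main() {
    float * d_out;
    cudaMalloc(&d_out, sizeof(float));
    test_warp_reduce_sum<<<1, 32>>>(d_out);
    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %.1f (expected 496.0)\n", h_out);
    cudaFree(d_out);
    return 0;
}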
@@ -578,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -625,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4718,7 +4726,6 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// TODO: maybe can be improved with some warp-based primitives
 static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
@@ -4726,24 +4733,26 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     const int block_size = blockDim.x;
 
-    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
-
-    buf[tid] = -INFINITY;
+    float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
-    __syncthreads();
-
     // find the max value in the block
-    for (int i = block_size/2; i > 0; i >>= 1) {
-        if (tid < i) {
-            buf[tid] = max(buf[tid], buf[tid + i]);
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
         }
         __syncthreads();
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
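The hunk above swaps the shared-memory tree reduction for the standard two-level pattern: each warp reduces its values in registers, lane 0 of every warp publishes its partial maximum to shared memory, and after a single barrier the per-warp partials are reduced once more with the same warp primitive. A sketch of the pattern in isolation, assuming WARP_SIZE is 32, a power-of-two block size of at most 1024, and the warp_reduce_max helper from the first hunk; the lane guard when re-reading the buffer is a defensive addition, not something the kernel above does:

#include <math.h>

#define WARP_SIZE 32

// two-level block max: warp shuffle pass, shared-memory exchange, second shuffle pass
static __device__ float block_reduce_max(float x) {
    x = warp_reduce_max(x); // level 1: intra-warp, registers only
    if (blockDim.x > WARP_SIZE) {
        __shared__ float buf[1024/WARP_SIZE]; // one slot per warp, 32 at most
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            buf[warp_id] = x; // lane 0 publishes its warp's maximum
        }
        __syncthreads();
        // level 2: every warp re-reduces the per-warp partials
        x = lane_id < blockDim.x/WARP_SIZE ? buf[lane_id] : -INFINITY;
        x = warp_reduce_max(x);
    }
    return x; // all threads hold the block maximum
}

Compared with the old tree loop, each thread touches shared memory once rather than on every iteration, and one __syncthreads replaces the log2(block_size) barriers of the tree.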
@@ -4751,26 +4760,26 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
         dst[ix] = val;
     }
 
-    __syncthreads();
-
-    buf[tid] = tmp;
-
-    __syncthreads();
-
-    // sum up partial sums
-    for (int i = block_size/2; i > 0; i >>= 1) {
-        if (tid < i) {
-            buf[tid] += buf[tid + i];
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
         }
         __syncthreads();
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
-    const float inv_tmp = 1.f / buf[0];
+    const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int i = rowx*ncols + col;
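Taken together, the two rewritten reductions implement the numerically stable softmax in three passes over each row: find the maximum, accumulate shifted exponentials, then normalize. Subtracting the row maximum before calling expf prevents overflow without changing the result, since with s_i = scale * x_i + y_i:

\[
\mathrm{softmax}(s)_i = \frac{e^{s_i}}{\sum_j e^{s_j}} = \frac{e^{s_i - m}}{\sum_j e^{s_j - m}},
\qquad m = \max_j s_j .
\]

The kernel stores e^{s_i - m} into dst during the second pass and scales by the reciprocal of the block-reduced sum in the third.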