mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
2x faster (rms) norm cuda kernels (3.7% e2e improvement) (#2985)
* 2x faster (rms) norm cuda kernels * Fix code style
This commit is contained in:
parent
cf9b08485c
commit
35195689cd
89
ggml-cuda.cu
89
ggml-cuda.cu
@ -464,58 +464,91 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
|||||||
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
|
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
|
||||||
|
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
|
||||||
|
}
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int block_size>
|
||||||
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
|
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
|
||||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
const float eps = 1e-5f;
|
const float eps = 1e-5f;
|
||||||
|
|
||||||
float mean = 0.0f;
|
float2 mean_var = make_float2(0.f, 0.f);
|
||||||
float var = 0.0f;
|
|
||||||
|
|
||||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
const float xi = x[row*ncols + col];
|
const float xi = x[row*ncols + col];
|
||||||
mean += xi;
|
mean_var.x += xi;
|
||||||
var += xi * xi;
|
mean_var.y += xi * xi;
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums
|
// sum up partial sums
|
||||||
#pragma unroll
|
mean_var = warp_reduce_sum(mean_var);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
if (block_size > WARP_SIZE) {
|
||||||
mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
|
__shared__ float2 s_sum[32];
|
||||||
var += __shfl_xor_sync(0xffffffff, var, mask, 32);
|
int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
|
int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
if (lane_id == 0) {
|
||||||
|
s_sum[warp_id] = mean_var;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
mean_var = s_sum[lane_id];
|
||||||
|
mean_var = warp_reduce_sum(mean_var);
|
||||||
}
|
}
|
||||||
|
|
||||||
mean /= ncols;
|
const float mean = mean_var.x / ncols;
|
||||||
var = var / ncols - mean * mean;
|
const float var = mean_var.y / ncols - mean * mean;
|
||||||
const float inv_var = rsqrtf(var + eps);
|
const float inv_std = rsqrtf(var + eps);
|
||||||
|
|
||||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
|
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float warp_reduce_sum(float x) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
|
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int block_size>
|
||||||
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
float tmp = 0.0f; // partial sum for thread in warp
|
float tmp = 0.0f; // partial sum for thread in warp
|
||||||
|
|
||||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
const float xi = x[row*ncols + col];
|
const float xi = x[row*ncols + col];
|
||||||
tmp += xi * xi;
|
tmp += xi * xi;
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums
|
// sum up partial sums
|
||||||
#pragma unroll
|
tmp = warp_reduce_sum(tmp);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
if (block_size > WARP_SIZE) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
__shared__ float s_sum[32];
|
||||||
|
int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
|
int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
if (lane_id == 0) {
|
||||||
|
s_sum[warp_id] = tmp;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
tmp = s_sum[lane_id];
|
||||||
|
tmp = warp_reduce_sum(tmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
const float mean = tmp / ncols;
|
const float mean = tmp / ncols;
|
||||||
const float scale = rsqrtf(mean + eps);
|
const float scale = rsqrtf(mean + eps);
|
||||||
|
|
||||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
dst[row*ncols + col] = scale * x[row*ncols + col];
|
dst[row*ncols + col] = scale * x[row*ncols + col];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -4203,14 +4236,24 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
|||||||
|
|
||||||
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
if (ncols < 1024) {
|
||||||
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||||
|
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
||||||
|
} else {
|
||||||
|
const dim3 block_dims(1024, 1, 1);
|
||||||
|
norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
if (ncols < 1024) {
|
||||||
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||||
|
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||||
|
} else {
|
||||||
|
const dim3 block_dims(1024, 1, 1);
|
||||||
|
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
|
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
|
||||||
|
Loading…
Reference in New Issue
Block a user