From 88519fbf97abcbf6b23de1027f4d2ac76cf50166 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 Nov 2023 15:34:20 +0200 Subject: [PATCH] cuda : implement soft_max_ext --- ggml-cuda.cu | 35 +++++++++++++++++++++-------------- ggml.c | 6 ++++++ llama.cpp | 1 + 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5b80e4ae3..628f2dcbc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4719,16 +4719,18 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int // the CUDA soft max implementation differs from the CPU implementation // instead of doubles floats are used -static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) { - const int row = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) { + const int rowx = blockDim.x*blockIdx.x + threadIdx.x; + const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension const int block_size = blockDim.y; const int tid = threadIdx.y; float max_val = -INFINITY; for (int col = tid; col < ncols; col += block_size) { - const int i = row*ncols + col; - max_val = max(max_val, x[i]); + const int ix = rowx*ncols + col; + const int iy = rowy*ncols + col; + max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f)); } // find the max value in the block @@ -4740,10 +4742,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol float tmp = 0.f; for (int col = tid; col < ncols; col += block_size) { - const int i = row*ncols + col; - const float val = expf(x[i] - max_val); + const int ix = rowx*ncols + col; + const int iy = rowy*ncols + col; + const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val); tmp += val; - dst[i] = val; + dst[ix] = val; } // sum up partial sums @@ -4755,7 +4758,7 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol const float inv_tmp = 1.f / tmp; for (int col = tid; col < ncols; col += block_size) { - const int i = row*ncols + col; + const int i = rowx*ncols + col; dst[i] *= inv_tmp; } } @@ -5792,10 +5795,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { const dim3 block_dims(1, WARP_SIZE, 1); const dim3 block_nums(nrows_x, 1, 1); - soft_max_f32<<>>(x, dst, ncols_x); + soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); } static void im2col_f32_f16_cuda(const float * x, half * dst, @@ -6846,14 +6849,18 @@ inline void ggml_cuda_op_soft_max( GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0; - soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); + float scale = 1.0f; + memcpy(&scale, dst->op_params, sizeof(float)); + + soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); - (void) src1; (void) dst; - (void) src1_dd; } inline void ggml_cuda_op_scale( diff --git a/ggml.c b/ggml.c index a0b04cbeb..788cabd84 100644 --- a/ggml.c +++ b/ggml.c @@ -4829,6 +4829,12 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * mask, float scale, bool inplace) { + if (mask) { + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + } + bool is_node = false; if (a->grad) { diff --git a/llama.cpp b/llama.cpp index ba837e26f..2c13aeb50 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5048,6 +5048,7 @@ static const std::unordered_map k_offload_map { "kq_scaled_alibi", OFFLOAD_FUNC_KQ }, { "kq_masked", OFFLOAD_FUNC_KQ }, { "kq_soft_max", OFFLOAD_FUNC_V }, + { "kq_soft_max_ext", OFFLOAD_FUNC_V }, { "v", OFFLOAD_FUNC_V }, { "kqv", OFFLOAD_FUNC_V }, { "kqv_merged", OFFLOAD_FUNC_V },