cuda : implement soft_max_ext

This commit is contained in:
Georgi Gerganov 2023-11-29 15:34:20 +02:00
parent e89597c062
commit 88519fbf97
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 28 additions and 14 deletions

View File

@@ -4719,16 +4719,18 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
// the CUDA soft max implementation differs from the CPU implementation // the CUDA soft max implementation differs from the CPU implementation
// instead of doubles floats are used // instead of doubles floats are used
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) { static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
const int row = blockDim.x*blockIdx.x + threadIdx.x; const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
const int block_size = blockDim.y; const int block_size = blockDim.y;
const int tid = threadIdx.y; const int tid = threadIdx.y;
float max_val = -INFINITY; float max_val = -INFINITY;
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col; const int ix = rowx*ncols + col;
max_val = max(max_val, x[i]); const int iy = rowy*ncols + col;
max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
} }
// find the max value in the block // find the max value in the block
@@ -4740,10 +4742,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
float tmp = 0.f; float tmp = 0.f;
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col; const int ix = rowx*ncols + col;
const float val = expf(x[i] - max_val); const int iy = rowy*ncols + col;
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
tmp += val; tmp += val;
dst[i] = val; dst[ix] = val;
} }
// sum up partial sums // sum up partial sums
@@ -4755,7 +4758,7 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
const float inv_tmp = 1.f / tmp; const float inv_tmp = 1.f / tmp;
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col; const int i = rowx*ncols + col;
dst[i] *= inv_tmp; dst[i] *= inv_tmp;
} }
} }
@@ -5792,10 +5795,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past); diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
} }
static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) { static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
const dim3 block_dims(1, WARP_SIZE, 1); const dim3 block_dims(1, WARP_SIZE, 1);
const dim3 block_nums(nrows_x, 1, 1); const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x); soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
} }
static void im2col_f32_f16_cuda(const float * x, half * dst, static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -6846,14 +6849,18 @@ inline void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;
soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
(void) src1;
(void) dst; (void) dst;
(void) src1_dd;
} }
inline void ggml_cuda_op_scale( inline void ggml_cuda_op_scale(

6
ggml.c
View File

@@ -4829,6 +4829,12 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * mask, struct ggml_tensor * mask,
float scale, float scale,
bool inplace) { bool inplace) {
if (mask) {
GGML_ASSERT(mask->ne[2] == 1);
GGML_ASSERT(mask->ne[3] == 1);
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
}
bool is_node = false; bool is_node = false;
if (a->grad) { if (a->grad) {

View File

@@ -5048,6 +5048,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
{ "kq_scaled_alibi", OFFLOAD_FUNC_KQ }, { "kq_scaled_alibi", OFFLOAD_FUNC_KQ },
{ "kq_masked", OFFLOAD_FUNC_KQ }, { "kq_masked", OFFLOAD_FUNC_KQ },
{ "kq_soft_max", OFFLOAD_FUNC_V }, { "kq_soft_max", OFFLOAD_FUNC_V },
{ "kq_soft_max_ext", OFFLOAD_FUNC_V },
{ "v", OFFLOAD_FUNC_V }, { "v", OFFLOAD_FUNC_V },
{ "kqv", OFFLOAD_FUNC_V }, { "kqv", OFFLOAD_FUNC_V },
{ "kqv_merged", OFFLOAD_FUNC_V }, { "kqv_merged", OFFLOAD_FUNC_V },