cuda : adapt soft_max to F16 mask and pos
parent 3e318e764f
commit 08e69c5008
@@ -1,7 +1,7 @@
 #include "softmax.cuh"
 
 template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid = threadIdx.x;
@@ -43,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+        const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -114,7 +114,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     }
 }
 
-static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
@@ -168,14 +168,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const float * src0_d = (const float *)src0->data;
-    const float * src1_d = src1 ? (const float *)src1->data : nullptr;
+    const half * src1_d = src1 ? (const half *)src1->data : nullptr;
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -188,13 +188,13 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // positions tensor
-    float * src2_dd = nullptr;
+    half * src2_dd = nullptr;
 
     ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
-        src2_dd = (float *)src2->data;
+        src2_dd = (half *)src2->data;
     }
 
     soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
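
For reference, a minimal self-contained sketch of the pattern this diff applies: the softmax input x stays F32, while the optional mask and pos tensors are stored as F16 (half) and widened on the fly with __half2float right before the fused scale + mask + ALiBi-bias add. The kernel name soft_max_input_f16_sketch, the launch configuration, and the fixed slope are illustrative assumptions, not llama.cpp code; the real soft_max_f32 kernel above additionally performs the row-wise max/sum reduction and derives slope from max_bias, m0, m1 and n_head_log2.

// sketch only: one thread per column of a single row
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

static __global__ void soft_max_input_f16_sketch(
        const float * x, const half * mask, const half * pos, float * dst,
        const int ncols, const float scale, const float slope) {
    const int col = blockIdx.x*blockDim.x + threadIdx.x;
    if (col >= ncols) {
        return;
    }
    // widen the optional F16 mask/pos to F32 before the add, as in the diff
    const float m = mask ? __half2float(mask[col])       : 0.0f;
    const float p = pos  ? slope*__half2float(pos[col])  : 0.0f;
    dst[col] = x[col]*scale + m + p;
}

int main() {
    const int ncols = 8;
    float hx[ncols];
    half  hmask[ncols], hpos[ncols];
    for (int i = 0; i < ncols; ++i) {
        hx[i]    = (float) i;
        hmask[i] = __float2half(i % 2 ? -1.0f : 0.0f); // toy mask values
        hpos[i]  = __float2half((float) i);            // toy positions
    }

    float *dx, *ddst; half *dmask, *dpos;
    cudaMalloc((void **) &dx,    ncols*sizeof(float));
    cudaMalloc((void **) &ddst,  ncols*sizeof(float));
    cudaMalloc((void **) &dmask, ncols*sizeof(half));
    cudaMalloc((void **) &dpos,  ncols*sizeof(half));
    cudaMemcpy(dx,    hx,    ncols*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dmask, hmask, ncols*sizeof(half),  cudaMemcpyHostToDevice);
    cudaMemcpy(dpos,  hpos,  ncols*sizeof(half),  cudaMemcpyHostToDevice);

    soft_max_input_f16_sketch<<<1, 32>>>(dx, dmask, dpos, ddst, ncols, 1.0f, 0.5f);

    float hdst[ncols];
    cudaMemcpy(hdst, ddst, ncols*sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < ncols; ++i) {
        printf("col %d: %f\n", i, hdst[i]);
    }

    cudaFree(dx); cudaFree(ddst); cudaFree(dmask); cudaFree(dpos);
    return 0;
}

Storing mask and pos in F16 halves the bytes read per element for those tensors, while the accumulation stays in F32, so the result only differs by the usual rounding of the F16 inputs.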