llama : adapt to F16 KQ_pos

This commit is contained in:
Georgi Gerganov 2024-02-19 13:10:24 +02:00
parent 31109ca00a
commit f249c997a8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 13 additions and 8 deletions

View File

@ -6232,7 +6232,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
const int ix = rowx*ncols + col; const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col; const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? __half2float(slope*pos[col]) : 0.0f); const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);
vals[col] = val; vals[col] = val;
max_val = max(max_val, val); max_val = max(max_val, val);

2
ggml.c
View File

@ -5192,7 +5192,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(mask->type == GGML_TYPE_F16);
GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_matrix(mask)); GGML_ASSERT(ggml_is_matrix(mask));
GGML_ASSERT(ggml_can_repeat_rows(mask, a)); GGML_ASSERT(mask->ne[1] >= a->ne[1]);
} }
if (pos) { if (pos) {

View File

@ -102,7 +102,7 @@
#define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_NODES 8192
#define LLAMA_MAX_EXPERTS 8 #define LLAMA_MAX_EXPERTS 8
#define LLAMA_FLASH_ATTN //#define LLAMA_FLASH_ATTN
// //
// logging // logging
@ -4831,6 +4831,11 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * cur; struct ggml_tensor * cur;
#if defined(LLAMA_FLASH_ATTN) #if defined(LLAMA_FLASH_ATTN)
GGML_UNUSED(model);
GGML_UNUSED(n_ctx);
GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
// split cached v into n_head heads (not transposed) // split cached v into n_head heads (not transposed)
struct ggml_tensor * v = struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il], ggml_view_3d(ctx, kv.v_l[il],
@ -5260,7 +5265,7 @@ struct llm_build_context {
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
// positions of the tokens in the KV cache // positions of the tokens in the KV cache
struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16);
cb(KQ_pos, "KQ_pos", -1); cb(KQ_pos, "KQ_pos", -1);
// shift the entire K-cache if needed // shift the entire K-cache if needed
@ -5804,7 +5809,7 @@ struct llm_build_context {
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
// positions of the tokens in the KV cache // positions of the tokens in the KV cache
struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16);
cb(KQ_pos, "KQ_pos", -1); cb(KQ_pos, "KQ_pos", -1);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -6043,7 +6048,7 @@ struct llm_build_context {
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
// positions of the tokens in the KV cache // positions of the tokens in the KV cache
struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16);
cb(KQ_pos, "KQ_pos", -1); cb(KQ_pos, "KQ_pos", -1);
inpL = llm_build_norm(ctx0, inpL, hparams, inpL = llm_build_norm(ctx0, inpL, hparams,
@ -6140,7 +6145,7 @@ struct llm_build_context {
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
// positions of the tokens in the KV cache // positions of the tokens in the KV cache
struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16);
cb(KQ_pos, "KQ_pos", -1); cb(KQ_pos, "KQ_pos", -1);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {

View File

@ -1505,7 +1505,7 @@ struct test_attn : public test_case {
struct ggml_tensor * cur; struct ggml_tensor * cur;
cur = ggml_mul_mat (ctx, k, q); cur = ggml_mul_mat (ctx, k, q);
cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); cur = ggml_soft_max_ext(ctx, cur, mask, nullptr, 1.0f/sqrtf(hs), 0.0f);
cur = ggml_mul_mat (ctx, v, cur); cur = ggml_mul_mat (ctx, v, cur);
cur = ggml_permute (ctx, cur, 0, 2, 1, 3); cur = ggml_permute (ctx, cur, 0, 2, 1, 3);
cur = ggml_cont_2d (ctx, cur, hs*nh, nb); cur = ggml_cont_2d (ctx, cur, hs*nh, nb);