mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-30 22:03:03 +01:00)
llama : do not use KV cache for non-causal models
ggml-ci
commit eb42596277
parent d0347840c1
@@ -13,7 +13,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(i)*32}
+        json= {"content": str(0)*32}
     ) for i in range(n)])

     for response in responses:
@@ -2044,6 +2044,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --yarn-attn-factor N      YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf("  --yarn-beta-slow N        YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
     printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  --pooling {none,mean,cls}\n");
+    printf("                            pooling type for embeddings, use model default if unspecified\n");
     printf("  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                            not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -2284,6 +2286,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
            }
            params.yarn_beta_slow = std::stof(argv[i]);
        }
+       else if (arg == "--pooling")
+       {
+           if (++i >= argc) {
+               invalid_param = true;
+               break;
+           }
+           std::string value(argv[i]);
+           /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+           else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+           else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+           else { invalid_param = true; break; }
+       }
        else if (arg == "--threads" || arg == "-t")
        {
            if (++i >= argc)
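Note: the three pooling modes reduce per-token embeddings to a single sequence embedding. Below is a minimal standalone sketch of what each mode computes; the `pooling_type` enum and `pool_embeddings` helper are invented for illustration (the server itself uses the LLAMA_POOLING_TYPE_* constants parsed above):

#include <vector>

enum pooling_type { POOLING_NONE, POOLING_MEAN, POOLING_CLS };

// Reduce per-token embeddings (n_tokens rows of n_embd floats) according to
// the pooling mode; POOLING_NONE performs no reduction at all.
static std::vector<std::vector<float>> pool_embeddings(
        const std::vector<std::vector<float>> & tok, pooling_type type) {
    if (type == POOLING_NONE || tok.empty()) {
        return tok; // one embedding per token, unchanged
    }
    if (type == POOLING_CLS) {
        return { tok.front() }; // embedding of the first (CLS) token only
    }
    // POOLING_MEAN: element-wise average over all tokens
    std::vector<float> mean(tok[0].size(), 0.0f);
    for (const auto & t : tok) {
        for (size_t e = 0; e < t.size(); ++e) {
            mean[e] += t[e] / (float) tok.size();
        }
    }
    return { mean };
}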
llama.cpp (98 changed lines)
@@ -6117,39 +6117,38 @@ struct llm_build_context {
         cb(inpL, "inp_norm", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens]
+        struct ggml_tensor * KQ_mask = ggml_cont(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_tokens, n_tokens, n_tokens*ggml_type_size(lctx.inp_KQ_mask->type), 0));
+        cb(KQ_mask, "KQ_mask", -1); // [n_tokens, n_tokens]

         // iterate layers
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * cur = inpL;

+            struct ggml_tensor * Qcur;
+            struct ggml_tensor * Kcur;
+            struct ggml_tensor * Vcur;
+
             // self-attention
             if (model.arch == LLM_ARCH_BERT) {
-                struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+                Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                 cb(Qcur, "Qcur", il);

-                struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
+                Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
                 cb(Kcur, "Kcur", il);

-                struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
+                Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
                 cb(Vcur, "Vcur", il);

-                // seems like we just need to do this for Q?
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             } else {
                 // compute Q and K and RoPE them
                 cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);

-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));

                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
@@ -6168,13 +6167,41 @@ struct llm_build_context {
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
+
+            struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+            struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            cb(kq, "kq", il);
+
+            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+            cb(kq, "kq_soft_max_ext", il);
+
+            struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
+            cb(v, "v", il);
+
+            struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
+            cb(kqv, "kqv", il);
+
+            struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cb(kqv_merged, "kqv_merged", il);
+
+            cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+            cb(cur, "kqv_merged_cont", il);
+
+            ggml_build_forward_expand(gf, cur);
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+            if (model.layers[il].bo) {
+                cb(cur, "kqv_wo", il);
+            }
+
+            if (model.layers[il].bo) {
+                cur = ggml_add(ctx0, cur, model.layers[il].bo);
+            }
+            cb(cur, "kqv_out", il);

             // re-add the layer input
             cur = ggml_add(ctx0, cur, inpL);

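Note: the removed `llm_build_kv` call is unrolled above into the explicit attention steps (KQ scores, masked softmax via `ggml_soft_max_ext`, weighted sum of V, head merge), operating only on the tokens in the batch. The following is a minimal single-head sketch of the same computation on plain arrays, assuming one row per token and a precomputed additive mask (0.0f = visible, -INFINITY = blocked). All names are invented for illustration; the real code works on ggml tensors and broadcasts the mask across heads:

#include <algorithm>
#include <cmath>
#include <vector>

// Single-head scaled dot-product attention with an additive mask, mirroring
// the KQ -> soft_max -> V*P sequence built above. Assumes every token can
// attend to at least itself (true for the batch mask in this commit).
static std::vector<std::vector<float>> attention(
        const std::vector<std::vector<float>> & q,    // n_tokens x d
        const std::vector<std::vector<float>> & k,    // n_tokens x d
        const std::vector<std::vector<float>> & v,    // n_tokens x d_v
        const std::vector<std::vector<float>> & mask) // n_tokens x n_tokens
{
    const size_t n = q.size(), d = q[0].size(), dv = v[0].size();
    const float scale = 1.0f/std::sqrt((float) d);

    std::vector<std::vector<float>> out(n, std::vector<float>(dv, 0.0f));
    for (size_t j = 0; j < n; ++j) {
        // scaled scores of query j against every key, plus the additive mask
        std::vector<float> p(n);
        float pmax = -INFINITY;
        for (size_t i = 0; i < n; ++i) {
            float s = 0.0f;
            for (size_t e = 0; e < d; ++e) {
                s += q[j][e]*k[i][e];
            }
            p[i] = s*scale + mask[j][i];
            pmax = std::max(pmax, p[i]);
        }
        // softmax over the masked scores; blocked entries contribute 0
        float sum = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            p[i] = std::exp(p[i] - pmax);
            sum += p[i];
        }
        // output row j is the softmax-weighted sum of the value rows
        for (size_t i = 0; i < n; ++i) {
            for (size_t e = 0; e < dv; ++e) {
                out[j][e] += p[i]*v[i][e]/sum;
            }
        }
    }
    return out;
}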
@@ -7985,7 +8012,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
        ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
    }

-   {
+   if (hparams.causal_attn) {
        const int64_t n_kv     = kv_self.n;
        const int64_t n_tokens = batch.n_tokens;

@@ -8004,12 +8031,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                        (hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
                        f = -INFINITY;
                    } else {
-                       f = 0;
+                       f = 0.0f;
                    }
                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                }
            }
        }
+   } else {
+       // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
+       const int64_t n_tokens = batch.n_tokens;
+
+       assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+
+       float * data = (float *) lctx.inp_KQ_mask->data;
+
+       for (int h = 0; h < 1; ++h) {
+           for (int j = 0; j < n_tokens; ++j) {
+               const llama_seq_id seq_id = batch.seq_id[j][0];
+
+               for (int i = 0; i < n_tokens; ++i) {
+                   float f = -INFINITY;
+                   for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                       if (batch.seq_id[i][s] == seq_id) {
+                           f = 0.0f;
+                           break;
+                       }
+                   }
+
+                   data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
+               }
+           }
+       }
    }

    if (hparams.need_kq_pos) {
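Worked example: for a batch of 5 tokens where tokens 0-2 belong to sequence 0 and tokens 3-4 to sequence 1, the loop above yields a block-diagonal mask; note there is no causal (key index vs. query index) condition, so within a sequence every token sees every other token. A standalone sketch (toy values, invented names) that builds and prints the same mask:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // toy batch: tokens 0..2 form sequence 0, tokens 3..4 form sequence 1
    const std::vector<int> seq_id = { 0, 0, 0, 1, 1 };
    const int n_tokens = (int) seq_id.size();

    std::vector<float> mask(n_tokens*n_tokens);
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_tokens; ++i) {
            // visible iff token i belongs to the same sequence as token j
            mask[j*n_tokens + i] = seq_id[i] == seq_id[j] ? 0.0f : -INFINITY;
        }
    }

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_tokens; ++i) {
            printf("%6s", mask[j*n_tokens + i] == 0.0f ? "0" : "-inf");
        }
        printf("\n");
    }
    return 0;
}

The printed pattern is a 3x3 block of zeros and a 2x2 block of zeros on the diagonal, with -inf everywhere else.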
@@ -8056,6 +8108,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
        const int64_t n_tokens = batch.n_tokens;

        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));

        uint32_t * data = (uint32_t *) lctx.inp_cls->data;

        for (int i = 0; i < n_tokens; ++i) {
@@ -8174,6 +8227,8 @@ static int llama_decode_internal(
        batch.seq_id = seq_id_arr.data();
    }

+   // non-causal masks do not use the KV cache
+   if (hparams.causal_attn) {
        llama_kv_cache_update(&lctx);

        // if we have enough unused cells before the current head ->
@@ -8191,6 +8246,7 @@ static int llama_decode_internal(
        // if we start defragmenting the cache, the benefit from this will be more important
        kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
        //kv_self.n = llama_kv_cache_cell_max(kv_self);
+   }

    //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
