mirror of https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00

llama : store non-RoPEd K cache (WIP)

This commit is contained in:
parent fad56936d4
commit 784d14ed31

llama.cpp (42 changed lines)
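In short: K is now written to the KV cache without RoPE applied, and the whole cached K view is rotated on the fly at attention time instead. Q keeps its per-token positions (Q_pos = n_past .. n_past+N-1), while the cached K is rotated with absolute positions (K_pos = 0 .. n_past+N-1).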
@@ -2428,16 +2428,25 @@ static struct ggml_cgraph * llm_build_llama(
         }
     }

-    // KQ_pos - contains the positions
-    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    ggml_allocr_alloc(lctx.alloc, KQ_pos);
+    // Q_pos - contains the positions
+    struct ggml_tensor * Q_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    ggml_allocr_alloc(lctx.alloc, Q_pos);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
-        int * data = (int *) KQ_pos->data;
+        int * data = (int *) Q_pos->data;
         for (int i = 0; i < N; ++i) {
             data[i] = n_past + i;
         }
     }

+    struct ggml_tensor * K_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_past + N);
+    ggml_allocr_alloc(lctx.alloc, K_pos);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        int * data = (int *) K_pos->data;
+        for (int i = 0; i < n_past + N; ++i) {
+            data[i] = i;
+        }
+    }
+
     for (int il = 0; il < n_layer; ++il) {
         ggml_format_name(inpL, "layer_inp_%d", il);

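A minimal standalone sketch (example values, not from the diff) of the two position arrays the hunk above fills in: Q_pos holds the per-token positions of the new batch, while K_pos enumerates every cached position, since the un-RoPEd K cache is rotated from scratch on each evaluation.

#include <stdio.h>

int main(void) {
    const int n_past = 4;       // tokens already in the KV cache (example value)
    const int N      = 3;       // tokens in the current batch    (example value)

    int q_pos[3];               // sized to N above; positions for Q, one per new token
    for (int i = 0; i < N; ++i) {
        q_pos[i] = n_past + i;  // 4, 5, 6
    }

    int k_pos[7];               // sized to n_past + N above; positions for the K cache
    for (int i = 0; i < n_past + N; ++i) {
        k_pos[i] = i;           // 0 .. 6
    }

    for (int i = 0; i < N; ++i)          printf("q_pos[%d] = %d\n", i, q_pos[i]);
    for (int i = 0; i < n_past + N; ++i) printf("k_pos[%d] = %d\n", i, k_pos[i]);
    return 0;
}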
@@ -2474,14 +2483,18 @@ static struct ggml_cgraph * llm_build_llama(
             offload_func_kq(tmpq);
             ggml_set_name(tmpq, "tmpq");

-            struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+            // Note: we are not RoPE-ing K here
+            struct ggml_tensor * Kcur = tmpk;
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");

-            struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+            struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), Q_pos, n_embd_head, 0, 0, freq_base, freq_scale);
             offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");

+            struct ggml_tensor * ck;
+            struct ggml_tensor * cv;
+
             // store key and value to memory
             {
                 // compute the transposed [N, n_embd] V matrix
@@ -2504,9 +2517,11 @@ static struct ggml_cgraph * llm_build_llama(
                 offload_func_v(v);
                 ggml_set_name(v, "v");

-                // important: storing RoPE-ed version of K in the KV cache!
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+                ck = ggml_cpy(ctx0, Kcur, k);
+                cv = ggml_cpy(ctx0, Vcur, v);
+
+                ggml_build_forward_expand(gf, ck);
+                ggml_build_forward_expand(gf, cv);
             }

             struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
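Holding on to the ggml_cpy results in ck and cv, rather than only expanding them into the graph as before, appears to give the next hunk a handle on the cache-resident K copy: the K->src[1] = ck assignment below forces the copy to execute before the cached view is rotated.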
@@ -2515,13 +2530,18 @@ static struct ggml_cgraph * llm_build_llama(

             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_past + N, n_head_kv,
-                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        n_embd_head, n_head_kv, n_past + N,
                         ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
             offload_func_kq(K);
             ggml_set_name(K, "K");

+            // RoPE the K cache
+            K->src[1] = ck; // TODO: HACK!!
+            K = ggml_rope_custom(ctx0, K, K_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+            K = ggml_permute(ctx0, K, 0, 2, 1, 3);
+
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
             offload_func_kq(KQ);
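For reference, a minimal standalone sketch of the rotation ggml_rope_custom applies per head in its default mode (adjacent-pair rotation), assuming freq_base = 10000 and freq_scale = 1.0. Names and values here are illustrative, not from the diff; rope_sketch is a hypothetical helper. Because each cached key is now rotated by its absolute position from K_pos, the cache itself stays position-free.

#include <math.h>
#include <stdio.h>

// Sketch of adjacent-pair RoPE: for each (even, odd) pair of dims, rotate by
// theta = freq_scale * pos * freq_base^(-i/n_dims). For the cached K, "pos"
// is now the absolute index supplied via K_pos rather than a write-time value.
static void rope_sketch(float * x, int n_dims, int pos, float freq_base, float freq_scale) {
    const float p = freq_scale * (float) pos;
    for (int i = 0; i < n_dims; i += 2) {
        const float theta = p * powf(freq_base, -(float) i / (float) n_dims);
        const float c = cosf(theta);
        const float s = sinf(theta);
        const float x0 = x[i];
        const float x1 = x[i + 1];
        x[i]     = x0 * c - x1 * s;  // rotate the pair by theta
        x[i + 1] = x0 * s + x1 * c;
    }
}

int main(void) {
    float k[4] = { 1.0f, 0.0f, 1.0f, 0.0f };  // toy 4-dim key vector
    rope_sketch(k, 4, 6, 10000.0f, 1.0f);     // rotate as if cached at position 6
    printf("%f %f %f %f\n", k[0], k[1], k[2], k[3]);
    return 0;
}

The trade-off is visible in the graph: RoPE now runs over all n_past + N cached keys on every evaluation instead of once per token at insert time, which is presumably part of why the commit is marked WIP.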