mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 21:10:24 +01:00
llama : refactor k-shift implementation + KV defragmentation (#5691)
* llama : refactor k-shift implementation ggml-ci * llama : rename llama_kv_cache_seq_shift to llama_kv_cache_seq_add * llama : cont k-shift refactoring + normalize type names ggml-ci * minor : fix MPI builds * llama : reuse n_rot from the build context ggml-ci * llama : revert enum name changes from this PR ggml-ci * llama : update llama_rope_type * llama : add comment about rope values * llama : fix build * passkey : apply kv cache updates explicitly ggml-ci * llama : change name to llama_kv_cache_update() * llama : add llama_kv_cache_seq_pos_max() * passkey : fix llama_kv_cache_seq_pos_max() usage * llama : some llama_kv_cell simplifications * llama : add llama_kv_cache_compress (EXPERIMENTAL) * llama : add alternative KV cache merging (EXPERIMENTAL) * llama : add llama_kv_cache_defrag * llama : comments * llama : remove llama_kv_cache_compress will add in a separate PR ggml-ci * llama : defragment via non-overlapping moves * llama : ggml_graph based defrag implementation ggml-ci * llama : switch the loop order in build_defrag * llama : add comments
This commit is contained in:
parent
f7625019c5
commit
bf08e00643
@ -448,7 +448,7 @@ int main(int argc, char ** argv) {
|
|||||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||||
|
|
||||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
|
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
|
||||||
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
|
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
|
||||||
|
|
||||||
n_past -= n_discard;
|
n_past -= n_discard;
|
||||||
|
|
||||||
|
@ -549,7 +549,7 @@ int main(int argc, char ** argv) {
|
|||||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||||
|
|
||||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
||||||
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
||||||
|
|
||||||
n_past -= n_discard;
|
n_past -= n_discard;
|
||||||
|
|
||||||
@ -576,9 +576,9 @@ int main(int argc, char ** argv) {
|
|||||||
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
|
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
|
||||||
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
|
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
|
||||||
|
|
||||||
llama_kv_cache_seq_shift(ctx, 0, ga_i, n_past, ib*bd);
|
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
|
||||||
llama_kv_cache_seq_div (ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
|
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
|
||||||
llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
|
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
|
||||||
|
|
||||||
n_past -= bd;
|
n_past -= bd;
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ int main(int argc, char ** argv) {
|
|||||||
const int n_batch = ctx_params.n_batch;
|
const int n_batch = ctx_params.n_batch;
|
||||||
const int n_batch_grp = ctx_params.n_batch/n_grp;
|
const int n_batch_grp = ctx_params.n_batch/n_grp;
|
||||||
|
|
||||||
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);
|
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos);
|
||||||
|
|
||||||
// print the prompt token-by-token
|
// print the prompt token-by-token
|
||||||
|
|
||||||
@ -146,10 +146,11 @@ int main(int argc, char ** argv) {
|
|||||||
const int ib = i/n_batch - 1;
|
const int ib = i/n_batch - 1;
|
||||||
const int bd = n_batch_grp*(n_grp - 1);
|
const int bd = n_batch_grp*(n_grp - 1);
|
||||||
|
|
||||||
llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd);
|
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
|
||||||
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
||||||
|
llama_kv_cache_update (ctx);
|
||||||
|
|
||||||
n_past -= bd;
|
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
@ -180,9 +181,11 @@ int main(int argc, char ** argv) {
|
|||||||
LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
|
LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
|
||||||
|
|
||||||
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||||
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||||
|
llama_kv_cache_defrag (ctx);
|
||||||
|
llama_kv_cache_update (ctx);
|
||||||
|
|
||||||
n_past -= n_discard;
|
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
||||||
|
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
@ -209,9 +212,11 @@ int main(int argc, char ** argv) {
|
|||||||
LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
|
LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
|
||||||
|
|
||||||
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||||
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||||
|
llama_kv_cache_defrag (ctx);
|
||||||
|
llama_kv_cache_update (ctx);
|
||||||
|
|
||||||
n_past -= n_discard;
|
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1637,7 +1637,7 @@ struct llama_server_context
|
|||||||
{"n_cache_tokens", slot.cache_tokens.size()}
|
{"n_cache_tokens", slot.cache_tokens.size()}
|
||||||
});
|
});
|
||||||
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
||||||
llama_kv_cache_seq_shift(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
||||||
|
|
||||||
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
|
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
|
||||||
{
|
{
|
||||||
@ -1941,9 +1941,9 @@ struct llama_server_context
|
|||||||
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
||||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
||||||
|
|
||||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
||||||
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
||||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
|
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
|
||||||
|
|
||||||
slot.n_past_se -= bd;
|
slot.n_past_se -= bd;
|
||||||
|
|
||||||
|
34
llama.h
34
llama.h
@ -64,6 +64,15 @@ extern "C" {
|
|||||||
LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
|
LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// note: these values should be synchronized with ggml_rope
|
||||||
|
// TODO: maybe move this enum to ggml.h (ggml_rope_type)
|
||||||
|
enum llama_rope_type {
|
||||||
|
LLAMA_ROPE_TYPE_NONE = -1,
|
||||||
|
LLAMA_ROPE_TYPE_NORM = 0,
|
||||||
|
LLAMA_ROPE_TYPE_NEOX = 2,
|
||||||
|
LLAMA_ROPE_TYPE_GLM = 4,
|
||||||
|
};
|
||||||
|
|
||||||
enum llama_token_type {
|
enum llama_token_type {
|
||||||
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
|
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
|
||||||
LLAMA_TOKEN_TYPE_NORMAL = 1,
|
LLAMA_TOKEN_TYPE_NORMAL = 1,
|
||||||
@ -360,6 +369,7 @@ extern "C" {
|
|||||||
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||||
|
|
||||||
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
|
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
|
||||||
|
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
||||||
|
|
||||||
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||||
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||||
@ -514,10 +524,12 @@ extern "C" {
|
|||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
// If the KV cache is RoPEd, the KV data is updated accordingly
|
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||||
|
// - lazily on next llama_decode()
|
||||||
|
// - explicitly with llama_kv_cache_update()
|
||||||
// p0 < 0 : [0, p1]
|
// p0 < 0 : [0, p1]
|
||||||
// p1 < 0 : [p0, inf)
|
// p1 < 0 : [p0, inf)
|
||||||
LLAMA_API void llama_kv_cache_seq_shift(
|
LLAMA_API void llama_kv_cache_seq_add(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_seq_id seq_id,
|
llama_seq_id seq_id,
|
||||||
llama_pos p0,
|
llama_pos p0,
|
||||||
@ -525,7 +537,9 @@ extern "C" {
|
|||||||
llama_pos delta);
|
llama_pos delta);
|
||||||
|
|
||||||
// Integer division of the positions by factor of `d > 1`
|
// Integer division of the positions by factor of `d > 1`
|
||||||
// If the KV cache is RoPEd, the KV data is updated accordingly
|
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||||
|
// - lazily on next llama_decode()
|
||||||
|
// - explicitly with llama_kv_cache_update()
|
||||||
// p0 < 0 : [0, p1]
|
// p0 < 0 : [0, p1]
|
||||||
// p1 < 0 : [p0, inf)
|
// p1 < 0 : [p0, inf)
|
||||||
LLAMA_API void llama_kv_cache_seq_div(
|
LLAMA_API void llama_kv_cache_seq_div(
|
||||||
@ -535,6 +549,20 @@ extern "C" {
|
|||||||
llama_pos p1,
|
llama_pos p1,
|
||||||
int d);
|
int d);
|
||||||
|
|
||||||
|
// Returns the largest position present in the KV cache for the specified sequence
|
||||||
|
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
|
// Defragment the KV cache
|
||||||
|
// This will be applied:
|
||||||
|
// - lazily on next llama_decode()
|
||||||
|
// - explicitly with llama_kv_cache_update()
|
||||||
|
LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
|
||||||
|
|
||||||
|
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||||
|
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
|
||||||
|
|
||||||
//
|
//
|
||||||
// State / sessions
|
// State / sessions
|
||||||
//
|
//
|
||||||
|
Loading…
x
Reference in New Issue
Block a user