server : avoid context swaps by shifting the KV cache

This commit is contained in:
Georgi Gerganov 2023-09-28 19:03:36 +03:00
parent ce2d995af2
commit c5650ed470
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -381,6 +381,10 @@ struct llama_server_context
// compare the evaluated prompt with the new prompt // compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens); n_past = common_part(embd, prompt_tokens);
// since #3228 we now have to manually manage the KV cache
llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
embd = prompt_tokens; embd = prompt_tokens;
if (n_past == num_prompt_tokens) if (n_past == num_prompt_tokens)
{ {
@ -411,19 +415,27 @@ struct llama_server_context
if (embd.size() >= (size_t)params.n_ctx) if (embd.size() >= (size_t)params.n_ctx)
{ {
// Reset context // Shift context
const int n_left = (params.n_ctx - params.n_keep) / 2;
const int n_left = n_past - params.n_keep - 1;
const int n_discard = n_left/2;
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
{
embd[i - n_discard] = embd[i];
}
embd.resize(embd.size() - n_discard);
n_past -= n_discard;
std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
embd = new_tokens;
n_past = params.n_keep;
truncated = true; truncated = true;
LOG_VERBOSE("input truncated", { LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx}, {"n_ctx", params.n_ctx},
{"n_keep", params.n_keep}, {"n_keep", params.n_keep},
{"n_left", n_left}, {"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
}); });
} }
@ -435,9 +447,6 @@ struct llama_server_context
n_eval = params.n_batch; n_eval = params.n_batch;
} }
// since #3228 we now have to manually manage the KV cache
llama_kv_cache_tokens_rm(ctx, n_past, -1);
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads)) if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
{ {
LOG_ERROR("failed to eval", { LOG_ERROR("failed to eval", {