Mirror of https://github.com/ggerganov/llama.cpp.git
server : avoid context swaps by shifting the KV cache
Commit: c5650ed470
Parent: ce2d995af2
@@ -381,6 +381,10 @@ struct llama_server_context
 
         // compare the evaluated prompt with the new prompt
         n_past = common_part(embd, prompt_tokens);
+
+        // since #3228 we now have to manually manage the KV cache
+        llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
+
         embd = prompt_tokens;
         if (n_past == num_prompt_tokens)
         {
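The added call relies on the sequence-level KV cache API referenced in the comment (#3228): once the already evaluated tokens and the new prompt have been compared, only the part of the cache that diverges has to be dropped and re-evaluated. A minimal stand-alone sketch of that prefix-reuse step, outside the server code (the helper names common_prefix and reuse_prompt_prefix are illustrative, not part of llama.cpp):

    #include <vector>
    #include "llama.h"

    // Count how many leading tokens the cached sequence and the new prompt share.
    static size_t common_prefix(const std::vector<llama_token> & a,
                                const std::vector<llama_token> & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }

    // Reuse the cached prefix of sequence 0 and clear only the stale tail,
    // so decoding can resume at the returned position instead of position 0.
    static int reuse_prompt_prefix(llama_context * ctx,
                                   std::vector<llama_token> & cached,
                                   const std::vector<llama_token> & prompt,
                                   int n_ctx) {
        const int n_past = (int) common_prefix(cached, prompt);

        // drop all KV cells of sequence 0 in the range [n_past, n_ctx)
        llama_kv_cache_seq_rm(ctx, 0, n_past, n_ctx);

        cached = prompt;
        return n_past;
    }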
@@ -411,19 +415,27 @@ struct llama_server_context
 
         if (embd.size() >= (size_t)params.n_ctx)
         {
-            // Reset context
-            const int n_left = (params.n_ctx - params.n_keep) / 2;
+            // Shift context
 
-            std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
-            new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
-            embd = new_tokens;
-            n_past = params.n_keep;
+            const int n_left    = n_past - params.n_keep - 1;
+            const int n_discard = n_left/2;
+
+            llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+
+            for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
+            {
+                embd[i - n_discard] = embd[i];
+            }
+            embd.resize(embd.size() - n_discard);
+
+            n_past -= n_discard;
+
             truncated = true;
             LOG_VERBOSE("input truncated", {
                 {"n_ctx", params.n_ctx},
                 {"n_keep", params.n_keep},
                 {"n_left", n_left},
-                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
             });
         }
 
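This hunk replaces the old "reset and re-evaluate" path: instead of throwing the cache away and re-decoding roughly half of the context, the server keeps the first params.n_keep + 1 positions, discards the oldest half of what follows, and slides the surviving cells down so generation resumes immediately. A sketch of the same arithmetic with a worked example in the comments (the helper name shift_context is illustrative; it assumes embd mirrors the tokens behind sequence 0, as in the server):

    #include <vector>
    #include "llama.h"

    // Example: n_keep = 32, n_past = 511 (cache nearly full)
    //   n_left    = 511 - 32 - 1 = 478   -> tokens eligible for discarding
    //   n_discard = 478 / 2      = 239   -> oldest half of that window
    // Cells [33, 33 + 239) are removed, cells [33 + 239, 511) slide down by 239,
    // and generation continues at position 511 - 239 = 272 without re-evaluation.
    static void shift_context(llama_context * ctx, std::vector<llama_token> & embd,
                              int & n_past, int n_keep) {
        const int n_left    = n_past - n_keep - 1;
        const int n_discard = n_left / 2;

        // discard the oldest non-kept cells of sequence 0 ...
        llama_kv_cache_seq_rm   (ctx, 0, n_keep + 1,             n_keep + 1 + n_discard);
        // ... and move the positions of the remaining cells down by n_discard
        llama_kv_cache_seq_shift(ctx, 0, n_keep + 1 + n_discard, n_past, -n_discard);

        // keep the token mirror consistent with the shifted cache; this is
        // equivalent to the copy loop + resize() in the patch above
        embd.erase(embd.begin() + n_keep + 1, embd.begin() + n_keep + 1 + n_discard);

        n_past -= n_discard;
    }

The recurring "+ 1" presumably leaves the very first token (typically BOS) in place in addition to the n_keep kept tokens, so the discard window starts just after the protected prefix.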
@@ -435,9 +447,6 @@ struct llama_server_context
                 n_eval = params.n_batch;
             }
 
-            // since #3228 we now have to manually manage the KV cache
-            llama_kv_cache_tokens_rm(ctx, n_past, -1);
-
             if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
             {
                 LOG_ERROR("failed to eval", {
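With the cache now trimmed once, right after the prompt comparison, the unconditional llama_kv_cache_tokens_rm(ctx, n_past, -1) before every batch is redundant (and would wipe cells that the shift just preserved), so it is removed. What remains is a plain batched evaluation loop; a sketch of that loop (ctx, embd, n_past, n_batch and n_threads are assumed to be in scope, as in the server), reusing the same llama_batch_get_one/llama_decode call shape shown in the diff above:

    // Evaluate embd[n_past .. end) in chunks of at most n_batch tokens.
    // The three-argument llama_decode(ctx, batch, n_threads) matches the server
    // code of this era; newer llama.cpp versions take only (ctx, batch).
    while (n_past < (int) embd.size()) {
        int n_eval = (int) embd.size() - n_past;
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }

        if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), n_threads)) {
            // a non-zero return value means evaluation failed
            break;
        }

        n_past += n_eval;
    }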