From c5650ed470f6264046431c8b5e43cbfe3e680d17 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Sep 2023 19:03:36 +0300 Subject: [PATCH] server : avoid context swaps by shifting the KV cache --- examples/server/server.cpp | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c5b1328d9..273eb36f4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -381,6 +381,10 @@ struct llama_server_context // compare the evaluated prompt with the new prompt n_past = common_part(embd, prompt_tokens); + + // since #3228 we now have to manually manage the KV cache + llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx); + embd = prompt_tokens; if (n_past == num_prompt_tokens) { @@ -411,19 +415,27 @@ struct llama_server_context if (embd.size() >= (size_t)params.n_ctx) { - // Reset context - const int n_left = (params.n_ctx - params.n_keep) / 2; + // Shift context + + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; + + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + + for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++) + { + embd[i - n_discard] = embd[i]; + } + embd.resize(embd.size() - n_discard); + + n_past -= n_discard; - std::vector new_tokens(embd.begin(), embd.begin() + params.n_keep); - new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end()); - embd = new_tokens; - n_past = params.n_keep; truncated = true; LOG_VERBOSE("input truncated", { {"n_ctx", params.n_ctx}, {"n_keep", params.n_keep}, {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, }); } @@ -435,9 +447,6 @@ struct llama_server_context n_eval = params.n_batch; } - // since #3228 we now have to manually manage the KV cache - llama_kv_cache_tokens_rm(ctx, n_past, -1); - if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads)) { LOG_ERROR("failed to eval", {