From 337120cc0d933bd6870bb7acfd880fade2d26b6c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Oct 2023 18:29:22 +0300 Subject: [PATCH] llama : fix handling of "future" tokens when loading sessions --- examples/main/main.cpp | 3 ++ examples/parallel/parallel.cpp | 2 +- examples/server/server.cpp | 2 +- examples/speculative/speculative.cpp | 6 +-- llama.cpp | 60 ++++++++++++---------------- llama.h | 8 ++++ 6 files changed, 41 insertions(+), 40 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3a4ed3f78..7367ae362 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -543,6 +543,9 @@ int main(int argc, char ** argv) { if (i > 0) { embd.erase(embd.begin(), embd.begin() + i); } + + // remove any "future" tokens that we might have inherited from the session from the KV cache + llama_kv_cache_tokens_rm(ctx, n_past, -1); } // evaluate tokens in batches diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 0434ded23..ffd7b1db4 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -332,7 +332,7 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx); + llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6dda5e36b..921eb5da4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -448,7 +448,7 @@ struct llama_server_context n_past = common_part(embd, prompt_tokens); // since #3228 we now have to manually manage the KV cache - llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx); + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); embd = prompt_tokens; if (n_past == num_prompt_tokens) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index c5e5b234f..75a2e5e22 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { LOG("out of drafted tokens\n"); } - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx); + llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0)); ++n_past_dft; @@ -257,7 +257,7 @@ int main(int argc, char ** argv) { } // evaluate the drafted token on the draft model - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx); + llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1); llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0)); ++n_past_cur; @@ -267,7 +267,7 @@ int main(int argc, char ** argv) { } // evaluate the target model on the drafted tokens - llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx); + llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1); llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0)); ++n_past_tgt; diff --git a/llama.cpp b/llama.cpp index c8fb97702..c9d5fc37a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1281,8 +1281,8 @@ static bool llama_kv_cache_init( // find an empty slot of size "n_tokens" in the cache // updates the cache head static bool llama_kv_cache_find_slot( - struct llama_kv_cache & cache, - const struct llama_batch & batch) { + struct llama_kv_cache & cache, + const struct llama_batch & batch) { const uint32_t n_ctx = cache.size; const uint32_t n_tokens = batch.n_tokens; @@ -1350,10 +1350,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, } static void llama_kv_cache_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.cells[i].seq_id.erase(seq_id); @@ -1365,11 +1368,14 @@ static void llama_kv_cache_seq_rm( } static void llama_kv_cache_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.cells[i].seq_id.insert(seq_id_dst); @@ -1387,11 +1393,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id } static void llama_kv_cache_seq_shift( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.cells[i].pos += delta; @@ -7478,25 +7487,6 @@ void llama_batch_free(struct llama_batch batch) { int llama_decode( struct llama_context * ctx, struct llama_batch batch) { - // TODO: temporary solution to auto clear "future" tokens from the cache - // ref: https://github.com/ggerganov/llama.cpp/pull/3400 - if (batch.pos) { - std::map seq_min_pos; - for (int i = 0; i < batch.n_tokens; i++) { - if (seq_min_pos.count(batch.seq_id[i]) == 0) { - seq_min_pos[batch.seq_id[i]] = batch.pos[i]; - } else { - seq_min_pos[batch.seq_id[i]] = std::min(seq_min_pos[batch.seq_id[i]], batch.pos[i]); - } - } - - for (auto & kv : seq_min_pos) { - llama_kv_cache_seq_rm(ctx->kv_self, kv.first, kv.second, ctx->cparams.n_ctx); - } - } else { - llama_kv_cache_seq_rm(ctx->kv_self, batch.all_seq_id, batch.all_pos_0, ctx->cparams.n_ctx); - } - const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); diff --git a/llama.h b/llama.h index 4564d570e..c4333d652 100644 --- a/llama.h +++ b/llama.h @@ -330,12 +330,16 @@ extern "C" { "avoid using this, it will be removed in the future, instead - count the tokens in user code"); // Remove all tokens data of cells in [c0, c1) + // c0 < -1 : [0, c1] + // c1 < -1 : [c0, inf) LLAMA_API void llama_kv_cache_tokens_rm( struct llama_context * ctx, int32_t c0, int32_t c1); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // p0 < -1 : [0, p1] + // p1 < -1 : [p0, inf) LLAMA_API void llama_kv_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, @@ -344,6 +348,8 @@ extern "C" { // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + // p0 < -1 : [0, p1] + // p1 < -1 : [p0, inf) LLAMA_API void llama_kv_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, @@ -358,6 +364,8 @@ extern "C" { // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly + // p0 < -1 : [0, p1] + // p1 < -1 : [p0, inf) LLAMA_API void llama_kv_cache_seq_shift( struct llama_context * ctx, llama_seq_id seq_id,