From 0f332a91042dd11f565bc5d04ae68727a72a41aa Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 2 Oct 2023 16:42:14 +0300
Subject: [PATCH] llama : temp fix for clearing "future" tokens from the KV cache

---
 llama.cpp | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index f2f89ee1e..c8fb97702 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7478,6 +7478,25 @@ void llama_batch_free(struct llama_batch batch) {
 int llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
+    // TODO: temporary solution to auto clear "future" tokens from the cache
+    // ref: https://github.com/ggerganov/llama.cpp/pull/3400
+    if (batch.pos) {
+        std::map<llama_seq_id, llama_pos> seq_min_pos;
+        for (int i = 0; i < batch.n_tokens; i++) {
+            if (seq_min_pos.count(batch.seq_id[i]) == 0) {
+                seq_min_pos[batch.seq_id[i]] = batch.pos[i];
+            } else {
+                seq_min_pos[batch.seq_id[i]] = std::min(seq_min_pos[batch.seq_id[i]], batch.pos[i]);
+            }
+        }
+
+        for (auto & kv : seq_min_pos) {
+            llama_kv_cache_seq_rm(ctx->kv_self, kv.first, kv.second, ctx->cparams.n_ctx);
+        }
+    } else {
+        llama_kv_cache_seq_rm(ctx->kv_self, batch.all_seq_id, batch.all_pos_0, ctx->cparams.n_ctx);
+    }
+
     const int ret = llama_decode_internal(*ctx, batch);
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
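
The effect of the patch, from a caller's perspective: when a batch is decoded at a position earlier than what is already stored in the KV cache (e.g. after rewinding generation), any cached entries at or past the batch's minimum position per sequence are now removed automatically inside `llama_decode`, instead of requiring an explicit `llama_kv_cache_seq_rm` call beforehand. Below is a minimal caller-side sketch of that scenario, assuming the public llama.h API of this period (`llama_kv_cache_seq_rm`, `llama_batch_get_one`, `llama_n_ctx`, `llama_decode`); the function `rewind_and_redecode` and its parameters are hypothetical and not part of the patch.

```cpp
#include "llama.h"

// Hypothetical helper: re-decode `n_tokens` tokens of sequence 0 starting at an
// earlier position `rewind_pos`, discarding the previously cached "future" tokens.
static int rewind_and_redecode(struct llama_context * ctx,
                               llama_token * tokens, int n_tokens,
                               llama_pos rewind_pos) {
    // Before this patch, the caller had to clear the stale entries manually,
    // mirroring what llama_decode now does internally per sequence:
    //
    //     llama_kv_cache_seq_rm(ctx, 0, rewind_pos, llama_n_ctx(ctx));

    // llama_batch_get_one leaves batch.pos unset and relies on all_pos_0/all_seq_id,
    // so this batch is handled by the `else` branch of the patch, which clears
    // [all_pos_0, n_ctx) for all_seq_id before decoding.
    struct llama_batch batch = llama_batch_get_one(tokens, n_tokens, rewind_pos, 0);

    return llama_decode(ctx, batch);
}
```

For batches built with explicit per-token positions and sequence ids, the first branch of the patch applies instead: it computes the minimum position seen for each sequence in the batch and removes cached entries from that position up to `n_ctx` for that sequence only.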