From 4d76d762ef6d0292506007e6721a9f2b8bd52861 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Mon, 18 Sep 2023 15:53:03 +0300
Subject: [PATCH] llama : extend llama_kv_cache API

---
 examples/embd-input/embd-input-lib.cpp |  2 +-
 examples/perplexity/perplexity.cpp     | 41 ++++++++++++-----
 llama.cpp                              | 62 ++++++++++++++++++--------
 llama.h                                | 11 +++--
 4 files changed, 84 insertions(+), 32 deletions(-)

diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index ed0966a51..344a8b2c3 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false };
+        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, };
         if (llama_decode(ctx, batch, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 2a046d55e..fd2160bbf 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -79,7 +79,9 @@ static void write_logfile(
 static std::vector<float> softmax(const std::vector<float>& logits) {
     std::vector<float> probs(logits.size());
     float max_logit = logits[0];
-    for (float v : logits) max_logit = std::max(max_logit, v);
+    for (float v : logits) {
+        max_logit = std::max(max_logit, v);
+    }
     double sum_exp = 0.0;
     for (size_t i = 0; i < logits.size(); i++) {
         // Subtract the maximum logit value from the current logit value for numerical stability
@@ -88,15 +90,21 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
         sum_exp += exp_logit;
         probs[i] = exp_logit;
     }
-    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
+    for (size_t i = 0; i < probs.size(); i++) {
+        probs[i] /= sum_exp;
+    }
     return probs;
 }
 
 static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
     float max_logit = logits[0];
-    for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
+    for (int i = 1; i < n_vocab; ++i) {
+        max_logit = std::max(max_logit, logits[i]);
+    }
     double sum_exp = 0.0;
-    for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
+    for (int i = 0; i < n_vocab; ++i) {
+        sum_exp += expf(logits[i] - max_logit);
+    }
     return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
 }
 
@@ -107,7 +115,8 @@ static void process_logits(
     std::mutex mutex;
     int counter = 0;
     auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
-        double local_nll = 0, local_nll2 = 0;
+        double local_nll = 0;
+        double local_nll2 = 0;
         while (true) {
             std::unique_lock<std::mutex> lock(mutex);
             int i = counter++;
@@ -125,10 +134,13 @@
             prob_history[i] = results.prob;
         }
     };
-    for (auto & w : workers) w = std::thread(compute);
+    for (auto & w : workers) {
+        w = std::thread(compute);
+    }
     compute();
-    for (auto & w : workers) w.join();
-
+    for (auto & w : workers) {
+        w.join();
+    }
 }
 
 static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
@@ -151,8 +163,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         return {std::move(tokens), 0., {}, {}};
     }
 
-    std::vector<float> logit_history;
-    std::vector<float> prob_history;
+    std::vector<float> logit_history;
+    std::vector<float> prob_history;
 
     logit_history.resize(tokens.size());
     prob_history.resize(tokens.size());
@@ -194,6 +206,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
 
         const auto t_start = std::chrono::high_resolution_clock::now();
 
+        // clear the KV cache
+        llama_kv_cache_keep_seq(ctx, -1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
@@ -319,6 +334,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
         const auto t_start = std::chrono::high_resolution_clock::now();
 
+        // clear the KV cache
+        llama_kv_cache_keep_seq(ctx, -1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
@@ -549,6 +567,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             query_embd.resize(32);
         }
 
+        // clear the KV cache
+        llama_kv_cache_keep_seq(ctx, -1);
+
         auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
         if (logits.empty()) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
diff --git a/llama.cpp b/llama.cpp
index 4867d348f..ce7ea408b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1316,7 +1316,8 @@ static bool llama_kv_cache_find_slot(
     return true;
 }
 
-void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
+void llama_kv_cache_update(struct llama_kv_cache & cache) {
+    // compute new cell_max
     cache.cell_max = 0;
 
     for (uint32_t i = 0; i < cache.size; i++) {
@@ -1326,18 +1327,40 @@ void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
     }
 }
 
-void llama_kv_cache_clear(struct llama_kv_cache & cache, int32_t p0, int32_t p1) {
-    cache.head = p0;
+void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
+    if (c0 < 0) c0 = 0;
+    if (c1 < 0) c1 = cache.size;
 
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = cache.size;
-
-    for (int32_t i = p0; i < p1; ++i) {
+    for (int32_t i = c0; i < c1; ++i) {
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
 
-    llama_kv_cache_update_cell_max(cache);
+    llama_kv_cache_update(cache);
+}
+
+void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id)) {
+            cache.cells[i].seq_id.erase(seq_id);
+            if (cache.cells[i].seq_id.empty()) {
+                cache.cells[i].pos = -1;
+            }
+        }
+    }
+
+    llama_kv_cache_update(cache);
+}
+
+void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (!cache.cells[i].has_seq_id(seq_id)) {
+            cache.cells[i].pos = -1;
+            cache.cells[i].seq_id.clear();
+        }
+    }
+
+    llama_kv_cache_update(cache);
 }
 
 //
@@ -3968,10 +3991,6 @@ static bool llama_eval_internal(
         batch.seq_id = seq_id.data();
     }
 
-    if (batch.clear_kv) {
-        llama_kv_cache_clear(kv_self, 0, -1);
-    }
-
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return false;
     }
@@ -6803,8 +6822,16 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
     return ctx->kv_self.head;
 }
 
-void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1) {
-    llama_kv_cache_clear(ctx->kv_self, p0, p1);
+void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1) {
+    llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1);
+}
+
+void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_kv_cache_rm_seq(ctx->kv_self, seq_id);
+}
+
+void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
 }
 
 // Returns the *maximum* size of the state
@@ -7203,7 +7230,7 @@ int llama_eval(
         uint32_t   n_tokens,
              int   n_past,
              int   n_threads) {
-    llama_kv_cache_clear(ctx->kv_self, n_past, -1);
+    llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
 
     if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@@ -7226,9 +7253,9 @@ int llama_eval_embd(
         uint32_t   n_tokens,
              int   n_past,
              int   n_threads) {
-    llama_kv_cache_clear(ctx->kv_self, n_past, -1);
+    llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
 
-    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, n_past == 0, };
+    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, };
 
     if (!llama_eval_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@@ -7259,7 +7286,6 @@ struct llama_batch llama_batch_get_one(
         /*all_pos_0  =*/ pos_0,
         /*all_pos_1  =*/ 1,
         /*all_seq_id =*/ seq_id,
-        /*clear_kv   =*/ pos_0 == 0,
     };
 }
 
diff --git a/llama.h b/llama.h
index ae7ac5e3d..4b70509b0 100644
--- a/llama.h
+++ b/llama.h
@@ -84,8 +84,6 @@ extern "C" {
         llama_pos    all_pos_0;  // used if pos == NULL
         llama_pos    all_pos_1;  // used if pos == NULL
         llama_seq_id all_seq_id; // used if seq_id == NULL
-
-        bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations
     } llama_batch;
 
     enum llama_log_level {
@@ -323,7 +321,14 @@
     LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
             "avoid using this, it will be removed in the future, instead - count the tokens in user code");
 
-    LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
+    // Removes all tokens in cells [c0, c1)
+    LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
+
+    // Removes all tokens that belong to the specified sequence
+    LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id);
+
+    // Removes all tokens that do not belong to the specified sequence
+    LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
     //
     // State / sessions
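
A minimal usage sketch of how the new calls compose (illustrative only; it assumes an initialized llama_context * ctx, and the cell range and sequence ids below are hypothetical):

    // free the cells in [0, 32): their pos is reset to -1 and their seq_id
    // sets are cleared, so llama_kv_cache_find_slot can reuse them
    llama_kv_cache_rm_tokens(ctx, 0, 32);

    // detach sequence 1 from every cell; cells that belonged only to
    // sequence 1 become free
    llama_kv_cache_rm_seq(ctx, 1);

    // keep only the cells used by sequence 0, freeing everything else
    llama_kv_cache_keep_seq(ctx, 0);

    // no cell ever carries seq_id == -1, so keep_seq(-1) frees every cell;
    // this is how the perplexity changes above clear the whole cache
    llama_kv_cache_keep_seq(ctx, -1);

Each of these context-level wrappers forwards to the corresponding llama_kv_cache function in llama.cpp, which finishes by calling llama_kv_cache_update() to recompute cell_max.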