From f015b266892b1634ac5a738c5861ea6976848ab4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 18 Sep 2023 17:15:25 +0300 Subject: [PATCH] llama : more robust cell_max heuristic + wip shift --- examples/llama-bench/llama-bench.cpp | 4 ++ llama.cpp | 81 +++++++++++----------------- llama.h | 6 ++- 3 files changed, 39 insertions(+), 52 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 2551f8422..8fdbd8033 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -977,6 +977,8 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); + llama_kv_cache_keep_seq(ctx, -1); + // warmup run if (t.n_prompt > 0) { test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads); @@ -986,6 +988,8 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { + llama_kv_cache_keep_seq(ctx, -1); + uint64_t t_start = get_time_ns(); if (t.n_prompt > 0) { test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); diff --git a/llama.cpp b/llama.cpp index ce7ea408b..1ef615811 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1023,9 +1023,6 @@ struct llama_kv_cache { uint32_t head = 0; uint32_t size = 0; - // largest index of an occupied cell (used for a basic optimization heuristic) - uint32_t cell_max = 0; - std::vector cells; struct ggml_tensor * k = NULL; @@ -1229,8 +1226,6 @@ static bool llama_kv_cache_init( cache.head = 0; cache.size = n_ctx; - cache.cell_max = 0; - cache.cells.clear(); cache.cells.resize(n_ctx); @@ -1316,15 +1311,16 @@ static bool llama_kv_cache_find_slot( return true; } -void llama_kv_cache_update(struct llama_kv_cache & cache) { - // compute new cell_max - cache.cell_max = 0; +int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { + int32_t res = 0; for (uint32_t i = 0; i < cache.size; i++) { - if (cache.cells[i].pos >= 0) { - cache.cell_max = i + 1; + if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) { + res = i + 1; } } + + return res; } void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) { @@ -1335,8 +1331,6 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); } - - llama_kv_cache_update(cache); } void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) { @@ -1348,8 +1342,6 @@ void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) { } } } - - llama_kv_cache_update(cache); } void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) { @@ -1359,8 +1351,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) cache.cells[i].seq_id.clear(); } } +} - llama_kv_cache_update(cache); +void llama_kv_cache_shift( + struct llama_context & ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + auto & hparams = ctx.model.hparams; + auto & cache = ctx.kv_self; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].pos += delta; + } + } } // @@ -2587,7 +2593,7 @@ static struct ggml_cgraph * llm_build_llama( const int n_gpu_layers = model.n_gpu_layers; const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = kv_self.cell_max + n_tokens; + const int32_t n_kv = llama_kv_cache_cell_max(kv_self); auto & buf_compute = lctx.buf_compute; @@ -2678,13 +2684,6 @@ static struct ggml_cgraph * llm_build_llama( data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; } } - - // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation - for (int i = n_kv; i < n_ctx; ++i) { - if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) { - GGML_ASSERT(false && "cell_max is too small - this might indicate a bug"); - } - } } } } @@ -2952,7 +2951,7 @@ static struct ggml_cgraph * llm_build_baichaun( const int n_gpu_layers = model.n_gpu_layers; const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = kv_self.cell_max + n_tokens; + const int32_t n_kv = llama_kv_cache_cell_max(kv_self); auto & buf_compute = lctx.buf_compute; @@ -3043,13 +3042,6 @@ static struct ggml_cgraph * llm_build_baichaun( data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; } } - - // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation - for (int i = n_kv; i < n_ctx; ++i) { - if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) { - GGML_ASSERT(false && "cell_max is too small - this might indicate a bug"); - } - } } } } @@ -3334,7 +3326,7 @@ static struct ggml_cgraph * llm_build_falcon( const int n_gpu_layers = model.n_gpu_layers; const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = kv_self.cell_max + n_tokens; + const int32_t n_kv = llama_kv_cache_cell_max(kv_self); auto & buf_compute = lctx.buf_compute; @@ -3425,13 +3417,6 @@ static struct ggml_cgraph * llm_build_falcon( data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; } } - - // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation - for (int i = n_kv; i < n_ctx; ++i) { - if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) { - GGML_ASSERT(false && "cell_max is too small - this might indicate a bug"); - } - } } } } @@ -3671,7 +3656,7 @@ static struct ggml_cgraph * llm_build_starcoder( const float norm_eps = hparams.f_norm_eps; const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = kv_self.cell_max + n_tokens; + const int32_t n_kv = llama_kv_cache_cell_max(kv_self); auto & buf_compute = lctx.buf_compute; @@ -3754,13 +3739,6 @@ static struct ggml_cgraph * llm_build_starcoder( data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; } } - - // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation - for (int i = n_kv; i < n_ctx; ++i) { - if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) { - GGML_ASSERT(false && "cell_max is too small - this might indicate a bug"); - } - } } } } @@ -4055,8 +4033,7 @@ static bool llama_eval_internal( #endif // update the kv ring buffer - lctx.kv_self.head += n_tokens; - lctx.kv_self.cell_max = std::max(lctx.kv_self.cell_max, lctx.kv_self.head); + lctx.kv_self.head += n_tokens; #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) @@ -6834,6 +6811,10 @@ void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) { llama_kv_cache_keep_seq(ctx->kv_self, seq_id); } +void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + llama_kv_cache_shift(*ctx, seq_id, p0, p1, delta); +} + // Returns the *maximum* size of the state size_t llama_get_state_size(const struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. @@ -7130,8 +7111,6 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { ctx->kv_self.head = kv_ntok; ctx->kv_self.size = kv_size; - - ctx->kv_self.cell_max = kv_ntok; } const size_t nread = inp - src; diff --git a/llama.h b/llama.h index 4b70509b0..ec05fa6ea 100644 --- a/llama.h +++ b/llama.h @@ -321,7 +321,7 @@ extern "C" { LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), "avoid using this, it will be removed in the future, instead - count the tokens in user code"); - // Remove all tokens between cells [c0, c1) + // Remove all tokens data of cells in [c0, c1) LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1); // Removes all tokens that belong to the specified sequence @@ -330,6 +330,10 @@ extern "C" { // Removes all tokens that do not belong to the specified sequence LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id); + // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) + // If the KV cache is RoPEd, the KV data is updated accordingly + LLAMA_API void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); + // // State / sessions //