From 6952a460b9857d4d603853042e7b417c802d4137 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 18 Sep 2023 15:31:24 +0300
Subject: [PATCH] llama : add cell_max heuristic for more efficient kv_cache

---
 llama.cpp | 119 +++++++++++++++++++++++++++++++++++++++++-------------
 llama.h   |  12 +++++-
 2 files changed, 102 insertions(+), 29 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 601f557ef..4867d348f 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1023,6 +1023,9 @@ struct llama_kv_cache {
     uint32_t head = 0;
     uint32_t size = 0;
 
+    // largest index of an occupied cell (used for a basic optimization heuristic)
+    uint32_t cell_max = 0;
+
     std::vector<llama_kv_cell> cells;
 
     struct ggml_tensor * k = NULL;
@@ -1226,6 +1229,8 @@ static bool llama_kv_cache_init(
     cache.head = 0;
     cache.size = n_ctx;
 
+    cache.cell_max = 0;
+
     cache.cells.clear();
     cache.cells.resize(n_ctx);
 
@@ -1311,6 +1316,16 @@ static bool llama_kv_cache_find_slot(
     return true;
 }
 
+void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
+    cache.cell_max = 0;
+
+    for (uint32_t i = 0; i < cache.size; i++) {
+        if (cache.cells[i].pos >= 0) {
+            cache.cell_max = i + 1;
+        }
+    }
+}
+
 void llama_kv_cache_clear(struct llama_kv_cache & cache, int32_t p0, int32_t p1) {
     cache.head = p0;
 
@@ -1321,6 +1336,8 @@ void llama_kv_cache_clear(struct llama_kv_cache & cache, int32_t p0, int32_t p1)
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
+
+    llama_kv_cache_update_cell_max(cache);
 }
 
 //
@@ -2547,6 +2564,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv = kv_self.cell_max + n_tokens;
 
     auto & buf_compute = lctx.buf_compute;
 
@@ -2621,7 +2639,7 @@ static struct ggml_cgraph * llm_build_llama(
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
@@ -2629,9 +2647,19 @@ static struct ggml_cgraph * llm_build_llama(
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_ctx; ++i) {
-                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
-                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                const llama_pos pos = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+
+                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
+                for (int i = n_kv; i < n_ctx; ++i) {
+                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
+                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
                     }
                 }
             }
@@ -2725,7 +2753,7 @@ static struct ggml_cgraph * llm_build_llama(
 
         struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_ctx, n_head_kv,
+                        n_embd_head, n_kv, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -2738,7 +2766,7 @@ static struct ggml_cgraph * llm_build_llama(
         ggml_set_name(KQ, "KQ");
 
         // KQ_scaled = KQ / sqrt(n_embd_head)
-        // KQ_scaled shape [n_ctx, n_tokens, n_head, 1]
+        // KQ_scaled shape [n_kv, n_tokens, n_head, 1]
         struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
         offload_func_kq(KQ_scaled);
         ggml_set_name(KQ_scaled, "KQ_scaled");
@@ -2756,7 +2784,7 @@ static struct ggml_cgraph * llm_build_llama(
         // split cached V into n_head heads
         struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_ctx, n_embd_head, n_head_kv,
+                        n_kv, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
@@ -2901,6 +2929,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv = kv_self.cell_max + n_tokens;
 
     auto & buf_compute = lctx.buf_compute;
 
@@ -2975,7 +3004,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
@@ -2983,9 +3012,19 @@ static struct ggml_cgraph * llm_build_baichaun(
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_ctx; ++i) {
-                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
-                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                const llama_pos pos = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+
+                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
+                for (int i = n_kv; i < n_ctx; ++i) {
+                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
+                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
                     }
                 }
             }
@@ -3092,7 +3131,7 @@ static struct ggml_cgraph * llm_build_baichaun(
 
         struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_ctx, n_head_kv,
+                        n_embd_head, n_kv, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -3135,7 +3174,7 @@ static struct ggml_cgraph * llm_build_baichaun(
         // split cached V into n_head heads
         struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_ctx, n_embd_head, n_head_kv,
+                        n_kv, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
@@ -3272,6 +3311,7 @@ static struct ggml_cgraph * llm_build_falcon(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv = kv_self.cell_max + n_tokens;
 
     auto & buf_compute = lctx.buf_compute;
 
@@ -3346,7 +3386,7 @@ static struct ggml_cgraph * llm_build_falcon(
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
@@ -3354,9 +3394,19 @@ static struct ggml_cgraph * llm_build_falcon(
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_ctx; ++i) {
-                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
-                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                const llama_pos pos = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+
+                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
+                for (int i = n_kv; i < n_ctx; ++i) {
+                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
+                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
                     }
                 }
             }
@@ -3479,7 +3529,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
         struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_ctx, n_head_kv,
+                        n_embd_head, n_kv, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -3504,7 +3554,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
         struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_ctx, n_embd_head, n_head_kv,
+                        n_kv, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
@@ -3598,6 +3648,7 @@ static struct ggml_cgraph * llm_build_starcoder(
     const float norm_eps = hparams.f_norm_eps;
 
     const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv = kv_self.cell_max + n_tokens;
 
     auto & buf_compute = lctx.buf_compute;
 
@@ -3664,7 +3715,7 @@ static struct ggml_cgraph * llm_build_starcoder(
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
@@ -3672,9 +3723,19 @@ static struct ggml_cgraph * llm_build_starcoder(
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_ctx; ++i) {
-                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
-                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                const llama_pos pos = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+
+                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
+                for (int i = n_kv; i < n_ctx; ++i) {
+                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
+                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
                     }
                 }
             }
@@ -3727,7 +3788,7 @@ static struct ggml_cgraph * llm_build_starcoder(
 
         struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_ctx, n_head_kv,
+                        n_embd_head, n_kv, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -3753,7 +3814,7 @@ static struct ggml_cgraph * llm_build_starcoder(
         // split cached V into n_head heads
        struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_ctx, n_embd_head, n_head_kv,
+                        n_kv, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
@@ -3974,8 +4035,9 @@ static bool llama_eval_internal(
     ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
 #endif
 
-    // update the kv ring buffer head
-    lctx.kv_self.head += n_tokens;
+    // update the kv ring buffer
+    lctx.kv_self.head += n_tokens;
+    lctx.kv_self.cell_max = std::max(lctx.kv_self.cell_max, lctx.kv_self.head);
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
@@ -7040,6 +7102,9 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         }
 
         ctx->kv_self.head = kv_ntok;
+        ctx->kv_self.size = kv_size;
+
+        ctx->kv_self.cell_max = kv_ntok;
     }
 
     const size_t nread = inp - src;
diff --git a/llama.h b/llama.h
index b844e172b..ae7ac5e3d 100644
--- a/llama.h
+++ b/llama.h
@@ -316,15 +316,19 @@ extern "C" {
                              int n_threads);
 
     //
-    // KV cache API
+    // KV cache
     //
 
     // Returns the number of tokens in the KV cache
     LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
-            "avoid using this, it will be removed in the future");
+            "avoid using this, it will be removed in the future, instead - count the tokens in user code");
 
     LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
 
+    //
+    // State / sessions
+    //
+
     // Returns the maximum size in bytes of the state (rng, logits, embedding
     // and kv_cache) - will often be smaller after compacting tokens
    LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
@@ -342,6 +346,10 @@ extern "C" {
     LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
     LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
 
+    //
+    // Decoding
+    //
+
     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls
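For reference, a minimal standalone sketch of the heuristic this patch introduces, using simplified stand-in types (kv_cell, kv_cache and update_cell_max below are illustrative names, not the llama.cpp API): cell_max is one past the largest occupied KV cell, so the graph only has to mask and attend over n_kv = cell_max + n_tokens cells instead of the full n_ctx.

// sketch.cpp - illustrative only, assumes simplified types rather than llama.cpp internals
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_cell {
    int32_t pos = -1; // -1 means the cell is empty
};

struct kv_cache {
    uint32_t cell_max = 0;
    std::vector<kv_cell> cells;
};

// mirrors llama_kv_cache_update_cell_max(): one past the largest occupied index
static void update_cell_max(kv_cache & cache) {
    cache.cell_max = 0;
    for (uint32_t i = 0; i < cache.cells.size(); i++) {
        if (cache.cells[i].pos >= 0) {
            cache.cell_max = i + 1;
        }
    }
}

int main() {
    const uint32_t n_ctx    = 512;
    const uint32_t n_tokens = 8;

    kv_cache cache;
    cache.cells.resize(n_ctx);

    // pretend 100 tokens have already been decoded into cells [0, 100)
    for (uint32_t i = 0; i < 100; i++) {
        cache.cells[i].pos = (int32_t) i;
    }
    update_cell_max(cache);

    // the mask and the K/V views only need to cover n_kv cells, not n_ctx
    const uint32_t n_kv = cache.cell_max + n_tokens;
    printf("cell_max = %u, n_kv = %u (instead of n_ctx = %u)\n", cache.cell_max, n_kv, n_ctx);
    return 0;
}

This is the same bound that sizes KQ_mask and the K/V views in the graph builders above; the trade-off is that cell_max only grows during decoding (via std::max against head) and is only recomputed downward when llama_kv_cache_clear calls llama_kv_cache_update_cell_max.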