diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 0659ab6f1..fcbad37bb 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -57,6 +57,8 @@ int main(int argc, char ** argv) { return 1; } + llama_kv_cache * kv = llama_get_kv_cache(ctx); + const int32_t n_kv_max = llama_n_ctx(ctx); llama_batch batch = llama_batch_init(n_kv_max, 0, 1); @@ -132,7 +134,7 @@ int main(int argc, char ** argv) { const auto t_pp_start = ggml_time_us(); - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { LOG_ERR("%s: llama_decode() failed\n", __func__); @@ -141,7 +143,7 @@ int main(int argc, char ** argv) { if (is_pp_shared) { for (int32_t i = 1; i < pl; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_kv_cache_seq_cp(kv, 0, i, -1, -1); } } diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 413b71d34..adb4a60ad 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -342,7 +342,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { } static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { - llama_kv_cache_clear(ctx); + llama_kv_cache * kv = llama_get_kv_cache(ctx); + llama_kv_cache_clear(kv); if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 72eb46257..16437453e 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -13,6 +13,8 @@ static std::vector> encode(llama_context * ctx, const std::ve const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); + llama_kv_cache * kv = llama_get_kv_cache(ctx); + llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); for (uint64_t i = 0; i < sentences.size(); i++) { @@ -45,7 +47,7 @@ static std::vector> encode(llama_context * ctx, const std::ve } // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); llama_set_embeddings(ctx, true); llama_set_causal_attn(ctx, false); @@ -100,9 +102,11 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); + llama_kv_cache * kv = llama_get_kv_cache(ctx); + llama_token eos_token = llama_vocab_eos(vocab); - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index b5f3feb9f..5efe4f019 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -431,6 +431,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); + llama_kv_cache * kv = llama_get_kv_cache(ctx); + const bool add_bos = llama_vocab_get_add_bos(vocab); const int n_ctx = llama_n_ctx(ctx); @@ -497,7 +499,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); llama_batch batch = llama_batch_init(n_batch, 0, 1); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 489a208b6..de8e77695 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -139,6 +139,8 @@ int main(int argc, char ** argv) { return 1; } + llama_kv_cache * kv = llama_get_kv_cache(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); const int n_ctx_train = llama_model_n_ctx_train(model); @@ -332,8 +334,8 @@ int main(int argc, char ** argv) { LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_kv_cache_seq_rm (kv, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_add(kv, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); n_past -= n_discard; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 4ac19ca86..8843c0048 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1575,9 +1575,11 @@ int main(int argc, char ** argv) { return 1; } + llama_kv_cache * kv = llama_get_kv_cache(ctx); + test t(inst, lmodel, ctx); - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); // cool off before the test if (params.delay) { @@ -1617,7 +1619,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); uint64_t t_start = get_time_ns(); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 2f0898e62..1219c2074 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -60,6 +60,7 @@ int main(int argc, char ** argv) { llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); + llama_kv_cache * kv = llama_get_kv_cache(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -95,7 +96,7 @@ int main(int argc, char ** argv) { llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + llama_kv_cache_seq_cp(kv, 0, s, -1, -1); } const auto t_enc_end = ggml_time_us(); @@ -437,17 +438,17 @@ int main(int argc, char ** argv) { // KV cache management // if no verification token matched, we simply remove all cells from this batch -> no fragmentation - llama_kv_cache_seq_rm(ctx, -1, n_past, -1); + llama_kv_cache_seq_rm(kv, -1, n_past, -1); if (seq_id_best != 0) { // if a verification token matched, we keep the best sequence and remove the rest // this leads to some KV cache fragmentation - llama_kv_cache_seq_keep(ctx, seq_id_best); - llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); - llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); + llama_kv_cache_seq_keep(kv, seq_id_best); + llama_kv_cache_seq_cp (kv, seq_id_best, 0, -1, -1); + llama_kv_cache_seq_rm (kv, seq_id_best, -1, -1); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + llama_kv_cache_seq_cp(kv, 0, s, -1, -1); } } } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index dbd0444ec..8628f7318 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -35,6 +35,7 @@ int main(int argc, char ** argv){ llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); + llama_kv_cache * kv = llama_get_kv_cache(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -192,7 +193,7 @@ int main(int argc, char ** argv){ // KV cache management // clean the cache of draft tokens that weren't accepted - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + llama_kv_cache_seq_rm(kv, 0, n_past, -1); common_batch_clear(batch_tgt); common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 39666a0e8..02fd61762 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -164,6 +164,8 @@ int main(int argc, char ** argv) { return 1; } + llama_kv_cache * kv = llama_get_kv_cache(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); @@ -326,7 +328,7 @@ int main(int argc, char ** argv) { } // remove any "future" tokens that we might have inherited from the previous session - llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); + llama_kv_cache_seq_rm(kv, -1, n_matching_session_tokens, -1); } LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n", @@ -567,8 +569,8 @@ int main(int argc, char ** argv) { LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); + llama_kv_cache_seq_rm (kv, 0, params.n_keep , params.n_keep + n_discard); + llama_kv_cache_seq_add(kv, 0, params.n_keep + n_discard, n_past, -n_discard); n_past -= n_discard; @@ -591,9 +593,9 @@ int main(int argc, char ** argv) { LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); - llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); - llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); + llama_kv_cache_seq_add(kv, 0, ga_i, n_past, ib*bd); + llama_kv_cache_seq_div(kv, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_kv_cache_seq_add(kv, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); n_past -= bd; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7ef43d5e1..2ba0706dc 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -134,6 +134,7 @@ int main(int argc, char ** argv) { llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); + llama_kv_cache * kv = llama_get_kv_cache(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -201,7 +202,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= n_clients; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_kv_cache_seq_cp(kv, 0, i, -1, -1); } LOG_INF("\n"); @@ -233,9 +234,9 @@ int main(int argc, char ** argv) { if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { - llama_kv_cache_seq_rm(ctx, i, -1, -1); + llama_kv_cache_seq_rm(kv, i, -1, -1); // but keep the system prompt - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_kv_cache_seq_cp(kv, 0, i, -1, -1); } LOG_INF("%s: clearing the KV cache\n", __func__); @@ -371,8 +372,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); - llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); + llama_kv_cache_seq_rm(kv, client.id + 1, -1, -1); + llama_kv_cache_seq_cp(kv, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 5953928d4..e2764313b 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -86,6 +86,8 @@ int main(int argc, char ** argv) { return 1; } + llama_kv_cache * kv = llama_get_kv_cache(ctx); + auto sparams = llama_sampler_chain_default_params(); llama_sampler * smpl = llama_sampler_chain_init(sparams); @@ -132,11 +134,11 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - llama_kv_cache_update (ctx); + llama_kv_cache_seq_add(kv, 0, n_past - n_batch, n_past, ib*bd); + llama_kv_cache_seq_div(kv, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_update_kv_cache (ctx, kv); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_kv_cache_seq_pos_max(kv, 0) + 1; } common_batch_clear(batch); @@ -166,12 +168,12 @@ int main(int argc, char ** argv) { LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag (ctx); - llama_kv_cache_update (ctx); + llama_kv_cache_seq_rm (kv, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(kv, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag (kv); + llama_update_kv_cache (ctx, kv); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_kv_cache_seq_pos_max(kv, 0) + 1; common_batch_clear(batch); @@ -197,12 +199,12 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag (ctx); - llama_kv_cache_update (ctx); + llama_kv_cache_seq_rm (kv, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(kv, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag (kv); + llama_update_kv_cache (ctx, kv); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_kv_cache_seq_pos_max(kv, 0) + 1; } } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9bf6c5743..6c9f716ed 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -299,6 +299,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); + llama_kv_cache * kv = llama_get_kv_cache(ctx); + const bool add_bos = llama_vocab_get_add_bos(vocab); GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); @@ -360,7 +362,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); llama_batch batch = llama_batch_init(n_batch, 0, 1); @@ -450,6 +452,8 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); + llama_kv_cache * kv = llama_get_kv_cache(ctx); + const bool add_bos = llama_vocab_get_add_bos(vocab); GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); @@ -546,7 +550,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -741,6 +745,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); + llama_kv_cache * kv = llama_get_kv_cache(ctx); + // Calculates hellaswag score (acc_norm) from prompt // // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl @@ -923,7 +929,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { return; } - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1084,6 +1090,8 @@ static void winogrande_score(llama_context * ctx, const common_params & params) const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); + llama_kv_cache * kv = llama_get_kv_cache(ctx); + constexpr int k_min_trailing_ctx = 3; auto data = load_winogrande_from_csv(params.prompt); @@ -1202,7 +1210,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) return; } - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1388,6 +1396,8 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); + llama_kv_cache * kv = llama_get_kv_cache(ctx); + std::istringstream strstream(params.prompt); uint32_t n_task; strstream.read((char *)&n_task, sizeof(n_task)); @@ -1574,7 +1584,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par return; } - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1671,6 +1681,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); + llama_kv_cache * kv = llama_get_kv_cache(ctx); + if (params.logits_file.empty()) { LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__); return; @@ -1764,7 +1776,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } // clear the KV cache - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); llama_batch batch = llama_batch_init(n_batch, 0, 1); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 2439022a2..a907ea076 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -82,8 +82,10 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke } static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { + llama_kv_cache * kv = llama_get_kv_cache(ctx); + // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index dd9ea79e8..535c0ae67 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -743,8 +743,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt // Check if we have enough space in the context to evaluate this batch static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { + llama_kv_cache * kv = llama_get_kv_cache(ctx.get()); + const int n_ctx = llama_n_ctx(ctx.get()); - const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get()); + const int n_ctx_used = llama_kv_cache_used_cells(kv); if (n_ctx_used + batch.n_tokens > n_ctx) { printf("\033[0m\n"); printe("context size exceeded\n"); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index cf7cbd815..3839fbe8c 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -156,6 +156,8 @@ int main(int argc, char ** argv) { // make new context llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params)); + llama_kv_cache * kv3 = llama_get_kv_cache(ctx3); + llama_sampler * smpl3 = llama_sampler_chain_init(sparams); llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed)); @@ -196,7 +198,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); // erase whole kv - llama_kv_cache_clear(ctx3); + llama_kv_cache_clear(kv3); fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d1e8ee829..f049c7eaf 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1661,6 +1661,7 @@ struct server_context { llama_model * model = nullptr; llama_context * ctx = nullptr; + llama_kv_cache * kv = nullptr; const llama_vocab * vocab = nullptr; @@ -1721,6 +1722,8 @@ struct server_context { return false; } + kv = llama_get_kv_cache(ctx); + vocab = llama_model_get_vocab(model); n_ctx = llama_n_ctx(ctx); @@ -1958,7 +1961,7 @@ struct server_context { SRV_DBG("%s", "clearing KV cache\n"); // clear the entire KV cache - llama_kv_cache_clear(ctx); + llama_kv_cache_clear(kv); clean_kv_cache = false; } @@ -2500,8 +2503,8 @@ struct server_context { res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); res->t_start = metrics.t_start; - res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx); - res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx); + res->kv_cache_tokens_count = llama_kv_cache_n_tokens(kv); + res->kv_cache_used_cells = llama_kv_cache_used_cells(kv); res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; res->t_prompt_processing_total = metrics.t_prompt_processing_total; @@ -2617,7 +2620,7 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); + llama_kv_cache_seq_rm(kv, slot->id, -1, -1); slot->cache_tokens.clear(); auto res = std::make_unique(); @@ -2685,8 +2688,8 @@ struct server_context { SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); + llama_kv_cache_seq_rm (kv, slot.id, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(kv, slot.id, n_keep + n_discard, slot.n_past, -n_discard); if (slot.params.cache_prompt) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { @@ -2873,8 +2876,8 @@ struct server_context { const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c); - llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift); + llama_kv_cache_seq_rm (kv, slot.id, head_p, head_c); + llama_kv_cache_seq_add(kv, slot.id, head_c, -1, kv_shift); for (size_t i = 0; i < n_match; i++) { slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; @@ -2912,9 +2915,9 @@ struct server_context { } // keep only the common part - if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) { + if (!llama_kv_cache_seq_rm(kv, slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); + llama_kv_cache_seq_rm(kv, slot.id, -1, -1); // there is no common part left slot.n_past = 0; @@ -3154,7 +3157,7 @@ struct server_context { slot.cache_tokens.push_back(id); slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); - llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + llama_kv_cache_seq_rm(kv, slot.id, slot.n_past, -1); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 26422601d..51894070f 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -88,6 +88,8 @@ int main(int argc, char ** argv) { return 1; } + llama_kv_cache * kv = llama_get_kv_cache(ctx); + // initialize the sampler llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params()); llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1)); @@ -101,7 +103,7 @@ int main(int argc, char ** argv) { // tokenize the prompt const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); std::vector prompt_tokens(n_prompt_tokens); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) { + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_kv_cache_used_cells(kv) == 0, true) < 0) { GGML_ABORT("failed to tokenize the prompt\n"); } @@ -111,7 +113,7 @@ int main(int argc, char ** argv) { while (true) { // check if we have enough space in the context to evaluate this batch int n_ctx = llama_n_ctx(ctx); - int n_ctx_used = llama_get_kv_cache_used_cells(ctx); + int n_ctx_used = llama_kv_cache_used_cells(kv); if (n_ctx_used + batch.n_tokens > n_ctx) { printf("\033[0m\n"); fprintf(stderr, "context size exceeded\n"); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 403ba2dd2..24bdc806d 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -45,6 +45,8 @@ int main(int argc, char ** argv) { model_tgt = llama_init_tgt.model.get(); ctx_tgt = llama_init_tgt.context.get(); + llama_kv_cache * kv = llama_get_kv_cache(ctx_tgt); + const llama_vocab * vocab = llama_model_get_vocab(model_tgt); // load the draft model @@ -217,7 +219,7 @@ int main(int argc, char ** argv) { { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); - llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1); + llama_kv_cache_seq_rm(kv, 0, n_past, -1); } if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index c7ccea50d..b4e5259b5 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -90,6 +90,9 @@ int main(int argc, char ** argv) { model_dft = llama_init_dft.model.get(); ctx_dft = llama_init_dft.context.get(); + llama_kv_cache * kv_tgt = llama_get_kv_cache(ctx_tgt); + llama_kv_cache * kv_dft = llama_get_kv_cache(ctx_dft); + const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); @@ -420,14 +423,14 @@ int main(int argc, char ** argv) { { LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); - llama_kv_cache_seq_keep(ctx_dft, s_keep); - llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_dft, 0); + llama_kv_cache_seq_keep(kv_dft, s_keep); + llama_kv_cache_seq_cp (kv_dft, s_keep, 0, -1, -1); + llama_kv_cache_seq_keep(kv_dft, 0); - llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); - llama_kv_cache_seq_keep(ctx_tgt, s_keep); - llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_tgt, 0); + llama_kv_cache_seq_rm (kv_tgt, s_keep, n_past_tgt, -1); + llama_kv_cache_seq_keep(kv_tgt, s_keep); + llama_kv_cache_seq_cp (kv_tgt, s_keep, 0, -1, -1); + llama_kv_cache_seq_keep(kv_tgt, 0); } for (int s = 0; s < n_seq_dft; ++s) { @@ -444,8 +447,8 @@ int main(int argc, char ** argv) { common_batch_clear(batch_dft); common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); - // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); + llama_kv_cache_seq_rm(kv_dft, 0, n_past_dft, -1); + // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(kv_dft, batch_dft).c_str()); llama_decode(ctx_dft, batch_dft); ++n_past_dft; @@ -503,8 +506,8 @@ int main(int argc, char ** argv) { if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) { LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur); - llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_kv_cache_seq_rm(kv_dft, n_seq_cur, -1, -1); + llama_kv_cache_seq_cp(kv_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch for (int t = 0; t < batch_tgt.n_tokens; ++t) { @@ -585,9 +588,9 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens { - llama_kv_cache_seq_keep(ctx_tgt, 0); + llama_kv_cache_seq_keep(kv_tgt, 0); for (int s = 1; s < n_seq_dft; ++s) { - llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); + llama_kv_cache_seq_cp(kv_tgt, 0, s, -1, -1); } // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());