From 43d8d4bf9e88df10203f7d8d4a1107b84bebbcfd Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Mon, 10 Jun 2024 14:44:42 -0400
Subject: [PATCH] examples : replace llama_kv_cache_seq_* with llama_past_seq_*

---
 common/common.cpp                             |  2 +-
 examples/batched-bench/batched-bench.cpp      |  4 +-
 examples/batched.swift/Sources/main.swift     |  2 +-
 examples/batched/batched.cpp                  |  2 +-
 examples/embedding/embedding.cpp              |  2 +-
 examples/gritlm/gritlm.cpp                    |  4 +-
 examples/imatrix/imatrix.cpp                  |  2 +-
 examples/infill/infill.cpp                    |  4 +-
 examples/llama-bench/llama-bench.cpp          |  4 +-
 .../llama/src/main/cpp/llama-android.cpp      |  8 ++--
 .../llama.cpp.swift/LibLlama.swift            |  8 ++--
 examples/lookahead/lookahead.cpp              | 13 ++---
 examples/lookup/lookup.cpp                    |  3 +-
 examples/main/main.cpp                        | 21 +++++----
 examples/parallel/parallel.cpp                | 10 ++--
 examples/passkey/passkey.cpp                  | 28 +++++------
 examples/perplexity/perplexity.cpp            | 12 ++---
 examples/retrieval/retrieval.cpp              |  2 +-
 examples/save-load-state/save-load-state.cpp  |  2 +-
 examples/server/server.cpp                    | 47 ++++++++++---------
 examples/speculative/speculative.cpp          | 26 +++++-----
 llama.cpp                                     |  3 +-
 llama.h                                       | 28 +++++------
 23 files changed, 125 insertions(+), 112 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 1591790e6..d04e04741 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2366,7 +2366,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 
         std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
-        llama_kv_cache_clear(lctx);
+        llama_past_clear(lctx);
         llama_synchronize(lctx);
         llama_reset_timings(lctx);
     }
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 718f0a61a..114dd811e 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
 
                 const auto t_pp_start = ggml_time_us();
 
-                llama_kv_cache_clear(ctx);
+                llama_past_clear(ctx);
 
                 if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
                     LOG_TEE("%s: llama_decode() failed\n", __func__);
@@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
 
                 if (is_pp_shared) {
                     for (int32_t i = 1; i < pl; ++i) {
-                        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+                        llama_past_seq_cp(ctx, 0, i, -1, -1);
                    }
                }
diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift
index dbbd06da5..443a03d57 100644
--- a/examples/batched.swift/Sources/main.swift
+++ b/examples/batched.swift/Sources/main.swift
@@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 {
 }
 
 for i in 1 ..< n_parallel {
-    llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
+    llama_past_seq_cp(context, 0, Int32(i), -1, -1)
 }
 
 if n_parallel > 1 {
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 62d9b144d..888cf9e8e 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -112,7 +112,7 @@ int main(int argc, char ** argv) {
     // assign the system KV cache to all parallel sequences
     // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
     for (int32_t i = 1; i < n_parallel; ++i) {
-        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+        llama_past_seq_cp(ctx, 0, i, -1, -1);
     }
 
     if (n_parallel > 1) {
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 244751e00..9a7c32d6b 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -25,7 +25,7 @@ 
static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // run model fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 213515791..dd389ac00 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -43,7 +43,7 @@ static std::vector> encode(llama_context * ctx, const std::ve } // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); llama_set_causal_attn(ctx, false); // run model @@ -97,7 +97,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo const llama_model * mdl = llama_get_model(ctx); llama_token eos_token = llama_token_eos(mdl); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); llama_set_causal_attn(ctx, true); llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index e18f49563..c81590a3f 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -455,7 +455,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 0e4ec79c6..0a74b93ab 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -380,8 +380,8 @@ int main(int argc, char ** argv) { LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); n_past -= n_discard; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 5c31548a6..d48eb245d 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1360,7 +1360,7 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // warmup run if (t.n_prompt > 0) { @@ -1372,7 +1372,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); uint64_t t_start = get_time_ns(); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 874158ef0..57ee5a650 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( } batch->logits[batch->n_tokens - 1] = true; - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_pp_start = ggml_time_us(); if (llama_decode(context, *batch) != 0) { @@ -206,7 +206,7 @@ 
Java_android_llama_cpp_LLamaAndroid_bench_1model( LOGi("Benchmark text generation (tg)"); - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { @@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_end = ggml_time_us(); - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; @@ -439,5 +439,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { - llama_kv_cache_clear(reinterpret_cast(context)); + llama_past_clear(reinterpret_cast(context)); } diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 737f882fb..50fcaa12d 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -214,7 +214,7 @@ actor LlamaContext { } batch.logits[Int(batch.n_tokens) - 1] = 1 // true - llama_kv_cache_clear(context) + llama_past_clear(context) let t_pp_start = ggml_time_us() @@ -227,7 +227,7 @@ actor LlamaContext { // bench text generation - llama_kv_cache_clear(context) + llama_past_clear(context) let t_tg_start = ggml_time_us() @@ -246,7 +246,7 @@ actor LlamaContext { let t_tg_end = ggml_time_us() - llama_kv_cache_clear(context) + llama_past_clear(context) let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 @@ -296,7 +296,7 @@ actor LlamaContext { func clear() { tokens_list.removeAll() temporary_invalid_cchars.removeAll() - llama_kv_cache_clear(context) + llama_past_clear(context) } private func tokenize(text: String, add_bos: Bool) -> [llama_token] { diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index fb20ad93f..7f6e42e8d 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -96,7 +96,7 @@ int main(int argc, char ** argv) { llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + llama_past_seq_cp(ctx, 0, s, -1, -1); } const auto t_enc_end = ggml_time_us(); @@ -438,17 +438,18 @@ int main(int argc, char ** argv) { // KV cache management // if no verification token matched, we simply remove all cells from this batch -> no fragmentation - llama_kv_cache_seq_rm(ctx, -1, n_past, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx, -1, n_past, -1); if (seq_id_best != 0) { // if a verification token matched, we keep the best sequence and remove the rest // this leads to some KV cache fragmentation - llama_kv_cache_seq_keep(ctx, seq_id_best); - llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); - llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); + llama_past_seq_keep(ctx, seq_id_best); + llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1); + llama_past_seq_rm (ctx, seq_id_best, -1, -1); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + llama_past_seq_cp(ctx, 0, s, -1, -1); } } } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 80ecd925d..db861d6ad 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -195,7 +195,8 @@ int main(int argc, char ** argv){ // KV cache management // clean the cache of draft tokens 
that weren't accepted - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx, 0, n_past, -1); llama_batch_clear(batch_tgt); llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b97b7b793..446fe035c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -299,6 +299,10 @@ int main(int argc, char ** argv) { } n_matching_session_tokens++; } + + // remove any "future" tokens that we might have inherited from the previous session + n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1); + if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { LOG_TEE("%s: using full prompt from session file\n", __func__); } else if (n_matching_session_tokens >= embd_inp.size()) { @@ -310,9 +314,6 @@ int main(int argc, char ** argv) { LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n", __func__, n_matching_session_tokens, embd_inp.size()); } - - // remove any "future" tokens that we might have inherited from the previous session - llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); } LOGLN( @@ -325,6 +326,8 @@ int main(int argc, char ** argv) { LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1); session_tokens.resize(embd_inp.size() - 1); + } else { + session_tokens.resize(n_matching_session_tokens); } // number of tokens to keep when resetting context @@ -535,8 +538,8 @@ int main(int argc, char ** argv) { LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); + llama_past_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); + llama_past_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); n_past -= n_discard; @@ -563,9 +566,9 @@ int main(int argc, char ** argv) { LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); - llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); - llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); + llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd); + llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); n_past -= bd; @@ -579,6 +582,8 @@ int main(int argc, char ** argv) { if (n_session_consumed < (int) session_tokens.size()) { size_t i = 0; for ( ; i < embd.size(); i++) { + // TODO: are the session tokens guaranteed to all be matching here? + // Should n_matching_session_tokens be re-used instead? 
if (embd[i] != session_tokens[n_session_consumed]) { session_tokens.resize(n_session_consumed); break; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7faeaec97..f68478804 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -200,7 +200,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= n_clients; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("\n"); @@ -232,9 +232,9 @@ int main(int argc, char ** argv) { if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { - llama_kv_cache_seq_rm(ctx, i, -1, -1); + llama_past_seq_rm(ctx, i, -1, -1); // but keep the system prompt - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("%s: clearing the KV cache\n", __func__); @@ -371,8 +371,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); - llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); + llama_past_seq_rm(ctx, client.id + 1, -1, -1); + llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index d03215cd1..c6564c5cf 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -126,11 +126,11 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - llama_kv_cache_update (ctx); + llama_past_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; } llama_batch_clear(batch); @@ -160,12 +160,12 @@ int main(int argc, char ** argv) { LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag (ctx); - llama_kv_cache_update (ctx); + llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag(ctx); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; llama_batch_clear(batch); @@ -191,12 +191,12 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag (ctx); - llama_kv_cache_update (ctx); + llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag(ctx); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; } } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 
0bd78c21a..ad03b3bb5 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -575,7 +575,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -944,7 +944,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1221,7 +1221,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1594,7 +1594,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1780,7 +1780,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { } // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 55b7b2f70..bd7d06d37 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -81,7 +81,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // run model fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 00c2277ac..974dc3c3e 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -192,7 +192,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); // erase whole kv - llama_kv_cache_clear(ctx3); + llama_past_clear(ctx3); fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6ffaa8d9f..a04c47bae 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1107,7 +1107,7 @@ struct server_context { LOG_VERBOSE("clearing KV cache", {}); // clear the entire KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); clean_kv_cache = false; } @@ -1151,7 +1151,7 @@ struct server_context { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= params.n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } } @@ -1824,7 +1824,7 @@ struct 
server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + llama_past_seq_rm(ctx, slot->id + 1, -1, -1); slot->cache_tokens.clear(); server_task_result result; @@ -1939,8 +1939,8 @@ struct server_context { {"n_cache_tokens", slot.cache_tokens.size()} }); - llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); + llama_past_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); + llama_past_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); if (slot.params.cache_prompt) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { @@ -2155,23 +2155,28 @@ struct server_context { } // keep only the common part - int p0 = (int) system_tokens.size() + slot.n_past; - if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) { - // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1); + llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past; - p0 = (int) system_tokens.size(); - if (p0 != 0) { - // copy over the system prompt when there is one - llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1); + // for recurrent and hybrid models, sometimes it goes back further than asked + llama_pos new_p0 = llama_past_seq_rm(ctx, slot.id + 1, p0, -1); + + if (new_p0 < p0) { + GGML_ASSERT(new_p0 >= (llama_pos) system_tokens.size()); + + slot.n_past -= p0 - new_p0; + if (slot.ga_i > 0) { + // TODO: test with an hybrid model (e.g. Jamba) + slot.n_past_se -= p0 - new_p0; } - // there is no common part left (except for the system prompt) - slot.n_past = 0; - slot.n_past_se = 0; - slot.ga_i = 0; - // TODO: is the system prompt ever in the sampling context? 
+ // TODO: find a way to avoid rolling back the sampling context twice llama_sampling_reset(slot.ctx_sampling); + // push the prompt into the sampling context (do not apply grammar) + for (int i = 0; i < slot.n_past; ++i) { + llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); + } + + p0 = new_p0; } // remove the non-common part from the cache @@ -2273,9 +2278,9 @@ struct server_context { LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); + llama_past_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); + llama_past_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); + llama_past_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); slot.n_past_se -= bd; diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 0939a1a6a..3a1ef06a5 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -394,14 +394,15 @@ int main(int argc, char ** argv) { { LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); - llama_kv_cache_seq_keep(ctx_dft, s_keep); - llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_dft, 0); + llama_past_seq_keep(ctx_dft, s_keep); + llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1); + llama_past_seq_keep(ctx_dft, 0); - llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); - llama_kv_cache_seq_keep(ctx_tgt, s_keep); - llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_tgt, 0); + // FIXME: recurrent and hybrid models + llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); + llama_past_seq_keep(ctx_tgt, s_keep); + llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1); + llama_past_seq_keep(ctx_tgt, 0); } for (int s = 0; s < n_seq_dft; ++s) { @@ -418,7 +419,8 @@ int main(int argc, char ** argv) { llama_batch_clear(batch_dft); llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); llama_decode(ctx_dft, batch_dft); @@ -474,8 +476,8 @@ int main(int argc, char ** argv) { if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { LOG("splitting seq %3d into %3d\n", s, n_seq_cur); - llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1); + llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch for (int t = 0; t < batch_tgt.n_tokens; ++t) { @@ -553,9 +555,9 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens { - llama_kv_cache_seq_keep(ctx_tgt, 0); + llama_past_seq_keep(ctx_tgt, 
0); for (int s = 1; s < n_seq_dft; ++s) { - llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); + llama_past_seq_cp(ctx_tgt, 0, s, -1, -1); } // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); diff --git a/llama.cpp b/llama.cpp index 4b84313cf..2233161d8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2126,7 +2126,6 @@ struct llama_ubatch { llama_token * token; // [n_tokens] float * embd; // [n_embd, n_tokens] llama_pos * pos; // [n_tokens] - // FIXME: make all uses of this use n_seqs int32_t * n_seq_id; // [n_seqs] llama_seq_id ** seq_id; // [n_seqs] int8_t * output; // [n_tokens] @@ -18992,7 +18991,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { __func__, kv_head, kv_size, kv_self.size); } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); if (kv_buf_size) { const size_t pre_kv_buf_size = inp - src; diff --git a/llama.h b/llama.h index 0d9d52256..4ecfc5f3e 100644 --- a/llama.h +++ b/llama.h @@ -583,11 +583,11 @@ extern "C" { LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); // Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed - LLAMA_API void llama_cache_clear( + LLAMA_API void llama_past_clear( struct llama_context * ctx); LLAMA_API DEPRECATED(void llama_kv_cache_clear( struct llama_context * ctx), - "use llama_cache_clear instead"); + "use llama_past_clear instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // seq_id < 0 : match any sequence @@ -595,7 +595,7 @@ extern "C" { // p1 < 0 : [p0, inf) // Returns n_past (one more than the largest remaining pos in the seq_id) // which is only meaningful to handle for partial removals. - LLAMA_API llama_pos llama_cache_seq_rm( + LLAMA_API llama_pos llama_past_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -605,7 +605,7 @@ extern "C" { llama_seq_id seq_id, llama_pos p0, llama_pos p1), - "use llama_cache_seq_rm instead, and handle its return value for partial removals"); + "use llama_past_seq_rm instead, and handle its return value for partial removals"); // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence @@ -613,7 +613,7 @@ extern "C" { // p1 < 0 : [p0, inf) // Returns n_past (one more than the largest remaining pos in the destination seq_id) // which is only meaningful to handle when partially copying. 
- LLAMA_API llama_pos llama_cache_seq_cp( + LLAMA_API llama_pos llama_past_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, @@ -625,16 +625,16 @@ extern "C" { llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1), - "use llama_cache_seq_cp instead, and handle its return value for partial copies"); + "use llama_past_seq_cp instead, and handle its return value for partial copies"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_cache_seq_keep( + LLAMA_API void llama_past_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( struct llama_context * ctx, llama_seq_id seq_id), - "use llama_cache_seq_keep instead"); + "use llama_past_seq_keep instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -642,7 +642,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_cache_seq_add( + LLAMA_API void llama_past_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -654,7 +654,7 @@ extern "C" { llama_pos p0, llama_pos p1, llama_pos delta), - "use llama_cache_seq_add instead"); + "use llama_past_seq_add instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -662,7 +662,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_cache_seq_div( + LLAMA_API void llama_past_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -674,16 +674,16 @@ extern "C" { llama_pos p0, llama_pos p1, int d), - "use llama_cache_seq_div instead"); + "use llama_past_seq_div instead"); // Returns the largest position present in the KV and/or RS cache for the specified sequence - LLAMA_API llama_pos llama_cache_seq_pos_max( + LLAMA_API llama_pos llama_past_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id); LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id), - "use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); + "use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied:
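The functional change behind the renames above is spelled out in the llama.h hunks: llama_past_seq_rm() and llama_past_seq_cp() return the resulting n_past of the affected sequence, and the deprecation messages ask callers to handle that value for partial removals and copies, because a recurrent or hybrid model may roll back further than requested. The main.cpp and server.cpp hunks show the intended caller-side pattern; a minimal stand-alone sketch of the same idea follows (the helper name truncate_seq is illustrative only and is not part of this patch):

#include "llama.h"

// Illustrative helper: try to keep only the first `keep` tokens of `seq_id`.
// With a recurrent or hybrid model the cache may only be able to roll back to an
// earlier point, so the value returned by llama_past_seq_rm() can be smaller than
// `keep`; the caller should adopt it as its new n_past.
static llama_pos truncate_seq(llama_context * ctx, llama_seq_id seq_id, llama_pos keep) {
    const llama_pos n_past = llama_past_seq_rm(ctx, seq_id, keep, -1);
    // n_past <= keep here; tokens in [n_past, keep) are no longer in the cache and
    // have to be fed through llama_decode() again (with the sampling state rewound)
    // before generation can continue from position `keep`.
    return n_past;
}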