examples : replace llama_kv_cache_seq_* with llama_past_seq_*

This commit is contained in:
Francis Couture-Harpin 2024-06-10 14:44:42 -04:00
parent 372482dffe
commit 43d8d4bf9e
23 changed files with 125 additions and 112 deletions

View File

@ -2366,7 +2366,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), }; std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_clear(lctx); llama_past_clear(lctx);
llama_synchronize(lctx); llama_synchronize(lctx);
llama_reset_timings(lctx); llama_reset_timings(lctx);
} }

View File

@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
const auto t_pp_start = ggml_time_us(); const auto t_pp_start = ggml_time_us();
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
if (!decode_helper(ctx, batch, ctx_params.n_batch)) { if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__); LOG_TEE("%s: llama_decode() failed\n", __func__);
@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
if (is_pp_shared) { if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) { for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); llama_past_seq_cp(ctx, 0, i, -1, -1);
} }
} }

View File

@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 {
} }
for i in 1 ..< n_parallel { for i in 1 ..< n_parallel {
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) llama_past_seq_cp(context, 0, Int32(i), -1, -1)
} }
if n_parallel > 1 { if n_parallel > 1 {

View File

@ -112,7 +112,7 @@ int main(int argc, char ** argv) {
// assign the system KV cache to all parallel sequences // assign the system KV cache to all parallel sequences
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (int32_t i = 1; i < n_parallel; ++i) { for (int32_t i = 1; i < n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); llama_past_seq_cp(ctx, 0, i, -1, -1);
} }
if (n_parallel > 1) { if (n_parallel > 1) {

View File

@ -25,7 +25,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// run model // run model
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

View File

@ -43,7 +43,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
} }
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
llama_set_causal_attn(ctx, false); llama_set_causal_attn(ctx, false);
// run model // run model
@ -97,7 +97,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
const llama_model * mdl = llama_get_model(ctx); const llama_model * mdl = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl); llama_token eos_token = llama_token_eos(mdl);
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
llama_set_causal_attn(ctx, true); llama_set_causal_attn(ctx, true);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

View File

@ -455,7 +455,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;

View File

@ -380,8 +380,8 @@ int main(int argc, char ** argv) {
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard); n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard; n_past -= n_discard;

View File

@ -1360,7 +1360,7 @@ int main(int argc, char ** argv) {
test t(inst, lmodel, ctx); test t(inst, lmodel, ctx);
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// warmup run // warmup run
if (t.n_prompt > 0) { if (t.n_prompt > 0) {
@ -1372,7 +1372,7 @@ int main(int argc, char ** argv) {
} }
for (int i = 0; i < params.reps; i++) { for (int i = 0; i < params.reps; i++) {
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
uint64_t t_start = get_time_ns(); uint64_t t_start = get_time_ns();

View File

@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
} }
batch->logits[batch->n_tokens - 1] = true; batch->logits[batch->n_tokens - 1] = true;
llama_kv_cache_clear(context); llama_past_clear(context);
const auto t_pp_start = ggml_time_us(); const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) { if (llama_decode(context, *batch) != 0) {
@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
LOGi("Benchmark text generation (tg)"); LOGi("Benchmark text generation (tg)");
llama_kv_cache_clear(context); llama_past_clear(context);
const auto t_tg_start = ggml_time_us(); const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) { for (i = 0; i < tg; i++) {
@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
const auto t_tg_end = ggml_time_us(); const auto t_tg_end = ggml_time_us();
llama_kv_cache_clear(context); llama_past_clear(context);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@ -439,5 +439,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context)); llama_past_clear(reinterpret_cast<llama_context *>(context));
} }

View File

@ -214,7 +214,7 @@ actor LlamaContext {
} }
batch.logits[Int(batch.n_tokens) - 1] = 1 // true batch.logits[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context) llama_past_clear(context)
let t_pp_start = ggml_time_us() let t_pp_start = ggml_time_us()
@ -227,7 +227,7 @@ actor LlamaContext {
// bench text generation // bench text generation
llama_kv_cache_clear(context) llama_past_clear(context)
let t_tg_start = ggml_time_us() let t_tg_start = ggml_time_us()
@ -246,7 +246,7 @@ actor LlamaContext {
let t_tg_end = ggml_time_us() let t_tg_end = ggml_time_us()
llama_kv_cache_clear(context) llama_past_clear(context)
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@ -296,7 +296,7 @@ actor LlamaContext {
func clear() { func clear() {
tokens_list.removeAll() tokens_list.removeAll()
temporary_invalid_cchars.removeAll() temporary_invalid_cchars.removeAll()
llama_kv_cache_clear(context) llama_past_clear(context)
} }
private func tokenize(text: String, add_bos: Bool) -> [llama_token] { private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

View File

@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
for (int s = 1; s < W + G + 1; ++s) { for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); llama_past_seq_cp(ctx, 0, s, -1, -1);
} }
const auto t_enc_end = ggml_time_us(); const auto t_enc_end = ggml_time_us();
@ -438,17 +438,18 @@ int main(int argc, char ** argv) {
// KV cache management // KV cache management
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
llama_kv_cache_seq_rm(ctx, -1, n_past, -1); // FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx, -1, n_past, -1);
if (seq_id_best != 0) { if (seq_id_best != 0) {
// if a verification token matched, we keep the best sequence and remove the rest // if a verification token matched, we keep the best sequence and remove the rest
// this leads to some KV cache fragmentation // this leads to some KV cache fragmentation
llama_kv_cache_seq_keep(ctx, seq_id_best); llama_past_seq_keep(ctx, seq_id_best);
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); llama_past_seq_rm (ctx, seq_id_best, -1, -1);
for (int s = 1; s < W + G + 1; ++s) { for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); llama_past_seq_cp(ctx, 0, s, -1, -1);
} }
} }
} }

View File

@ -195,7 +195,8 @@ int main(int argc, char ** argv){
// KV cache management // KV cache management
// clean the cache of draft tokens that weren't accepted // clean the cache of draft tokens that weren't accepted
llama_kv_cache_seq_rm(ctx, 0, n_past, -1); // FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx, 0, n_past, -1);
llama_batch_clear(batch_tgt); llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

View File

@ -299,6 +299,10 @@ int main(int argc, char ** argv) {
} }
n_matching_session_tokens++; n_matching_session_tokens++;
} }
// remove any "future" tokens that we might have inherited from the previous session
n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1);
if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
LOG_TEE("%s: using full prompt from session file\n", __func__); LOG_TEE("%s: using full prompt from session file\n", __func__);
} else if (n_matching_session_tokens >= embd_inp.size()) { } else if (n_matching_session_tokens >= embd_inp.size()) {
@ -310,9 +314,6 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n", LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size()); __func__, n_matching_session_tokens, embd_inp.size());
} }
// remove any "future" tokens that we might have inherited from the previous session
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
} }
LOGLN( LOGLN(
@ -325,6 +326,8 @@ int main(int argc, char ** argv) {
LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1); LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1);
session_tokens.resize(embd_inp.size() - 1); session_tokens.resize(embd_inp.size() - 1);
} else {
session_tokens.resize(n_matching_session_tokens);
} }
// number of tokens to keep when resetting context // number of tokens to keep when resetting context
@ -535,8 +538,8 @@ int main(int argc, char ** argv) {
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard); n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); llama_past_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); llama_past_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard; n_past -= n_discard;
@ -563,9 +566,9 @@ int main(int argc, char ** argv) {
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
n_past -= bd; n_past -= bd;
@ -579,6 +582,8 @@ int main(int argc, char ** argv) {
if (n_session_consumed < (int) session_tokens.size()) { if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0; size_t i = 0;
for ( ; i < embd.size(); i++) { for ( ; i < embd.size(); i++) {
// TODO: are the session tokens guaranteed to all be matching here?
// Should n_matching_session_tokens be re-used instead?
if (embd[i] != session_tokens[n_session_consumed]) { if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed); session_tokens.resize(n_session_consumed);
break; break;

View File

@ -200,7 +200,7 @@ int main(int argc, char ** argv) {
// assign the system KV cache to all parallel sequences // assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= n_clients; ++i) { for (int32_t i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); llama_past_seq_cp(ctx, 0, i, -1, -1);
} }
LOG_TEE("\n"); LOG_TEE("\n");
@ -232,9 +232,9 @@ int main(int argc, char ** argv) {
if (batch.n_tokens == 0) { if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache // all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) { for (int i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, -1, -1); llama_past_seq_rm(ctx, i, -1, -1);
// but keep the system prompt // but keep the system prompt
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); llama_past_seq_cp(ctx, 0, i, -1, -1);
} }
LOG_TEE("%s: clearing the KV cache\n", __func__); LOG_TEE("%s: clearing the KV cache\n", __func__);
@ -371,8 +371,8 @@ int main(int argc, char ** argv) {
} }
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); llama_past_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1);
const auto t_main_end = ggml_time_us(); const auto t_main_end = ggml_time_us();

View File

@ -126,11 +126,11 @@ int main(int argc, char ** argv) {
const int ib = i/n_batch - 1; const int ib = i/n_batch - 1;
const int bd = n_batch_grp*(n_grp - 1); const int bd = n_batch_grp*(n_grp - 1);
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); llama_past_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_kv_cache_update(ctx); llama_kv_cache_update(ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; n_past = llama_past_seq_pos_max(ctx, 0) + 1;
} }
llama_batch_clear(batch); llama_batch_clear(batch);
@ -160,12 +160,12 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag(ctx); //llama_kv_cache_defrag(ctx);
llama_kv_cache_update(ctx); llama_kv_cache_update(ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; n_past = llama_past_seq_pos_max(ctx, 0) + 1;
llama_batch_clear(batch); llama_batch_clear(batch);
@ -191,12 +191,12 @@ int main(int argc, char ** argv) {
if (n_discard > 0) { if (n_discard > 0) {
LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag(ctx); //llama_kv_cache_defrag(ctx);
llama_kv_cache_update(ctx); llama_kv_cache_update(ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; n_past = llama_past_seq_pos_max(ctx, 0) + 1;
} }
} }

View File

@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
@ -575,7 +575,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
@ -944,7 +944,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
return; return;
} }
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// decode all tasks [i0, i1) // decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1221,7 +1221,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
return; return;
} }
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// decode all tasks [i0, i1) // decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1594,7 +1594,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
return; return;
} }
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// decode all tasks [i0, i1) // decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1780,7 +1780,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
} }
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;

View File

@ -81,7 +81,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// run model // run model
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

View File

@ -192,7 +192,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
// erase whole kv // erase whole kv
llama_kv_cache_clear(ctx3); llama_past_clear(ctx3);
fprintf(stderr, "%s : kv cache cleared\n", __func__); fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1 // restore kv into seq 1

View File

@ -1107,7 +1107,7 @@ struct server_context {
LOG_VERBOSE("clearing KV cache", {}); LOG_VERBOSE("clearing KV cache", {});
// clear the entire KV cache // clear the entire KV cache
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
clean_kv_cache = false; clean_kv_cache = false;
} }
@ -1151,7 +1151,7 @@ struct server_context {
// assign the system KV cache to all parallel sequences // assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= params.n_parallel; ++i) { for (int32_t i = 1; i <= params.n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); llama_past_seq_cp(ctx, 0, i, -1, -1);
} }
} }
@ -1824,7 +1824,7 @@ struct server_context {
// Erase token cache // Erase token cache
const size_t n_erased = slot->cache_tokens.size(); const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); llama_past_seq_rm(ctx, slot->id + 1, -1, -1);
slot->cache_tokens.clear(); slot->cache_tokens.clear();
server_task_result result; server_task_result result;
@ -1939,8 +1939,8 @@ struct server_context {
{"n_cache_tokens", slot.cache_tokens.size()} {"n_cache_tokens", slot.cache_tokens.size()}
}); });
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); llama_past_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); llama_past_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
if (slot.params.cache_prompt) { if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@ -2155,23 +2155,28 @@ struct server_context {
} }
// keep only the common part // keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past; llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
p0 = (int) system_tokens.size(); // for recurrent and hybrid models, sometimes it goes back further than asked
if (p0 != 0) { llama_pos new_p0 = llama_past_seq_rm(ctx, slot.id + 1, p0, -1);
// copy over the system prompt when there is one
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1); if (new_p0 < p0) {
GGML_ASSERT(new_p0 >= (llama_pos) system_tokens.size());
slot.n_past -= p0 - new_p0;
if (slot.ga_i > 0) {
// TODO: test with an hybrid model (e.g. Jamba)
slot.n_past_se -= p0 - new_p0;
} }
// there is no common part left (except for the system prompt) // TODO: find a way to avoid rolling back the sampling context twice
slot.n_past = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
llama_sampling_reset(slot.ctx_sampling); llama_sampling_reset(slot.ctx_sampling);
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
}
p0 = new_p0;
} }
// remove the non-common part from the cache // remove the non-common part from the cache
@ -2273,9 +2278,9 @@ struct server_context {
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); llama_past_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); llama_past_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); llama_past_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
slot.n_past_se -= bd; slot.n_past_se -= bd;

View File

@ -394,14 +394,15 @@ int main(int argc, char ** argv) {
{ {
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_keep(ctx_dft, s_keep); llama_past_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0); llama_past_seq_keep(ctx_dft, 0);
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); // FIXME: recurrent and hybrid models
llama_kv_cache_seq_keep(ctx_tgt, s_keep); llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); llama_past_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_keep(ctx_tgt, 0); llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_past_seq_keep(ctx_tgt, 0);
} }
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
@ -418,7 +419,8 @@ int main(int argc, char ** argv) {
llama_batch_clear(batch_dft); llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); // FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft); llama_decode(ctx_dft, batch_dft);
@ -474,8 +476,8 @@ int main(int argc, char ** argv) {
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur); LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch // all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) { for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@ -553,9 +555,9 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens // evaluate the target model on the drafted tokens
{ {
llama_kv_cache_seq_keep(ctx_tgt, 0); llama_past_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) { for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); llama_past_seq_cp(ctx_tgt, 0, s, -1, -1);
} }
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());

View File

@ -2126,7 +2126,6 @@ struct llama_ubatch {
llama_token * token; // [n_tokens] llama_token * token; // [n_tokens]
float * embd; // [n_embd, n_tokens] float * embd; // [n_embd, n_tokens]
llama_pos * pos; // [n_tokens] llama_pos * pos; // [n_tokens]
// FIXME: make all uses of this use n_seqs
int32_t * n_seq_id; // [n_seqs] int32_t * n_seq_id; // [n_seqs]
llama_seq_id ** seq_id; // [n_seqs] llama_seq_id ** seq_id; // [n_seqs]
int8_t * output; // [n_tokens] int8_t * output; // [n_tokens]
@ -18992,7 +18991,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
__func__, kv_head, kv_size, kv_self.size); __func__, kv_head, kv_size, kv_self.size);
} }
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
if (kv_buf_size) { if (kv_buf_size) {
const size_t pre_kv_buf_size = inp - src; const size_t pre_kv_buf_size = inp - src;

28
llama.h
View File

@ -583,11 +583,11 @@ extern "C" {
LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed // Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed
LLAMA_API void llama_cache_clear( LLAMA_API void llama_past_clear(
struct llama_context * ctx); struct llama_context * ctx);
LLAMA_API DEPRECATED(void llama_kv_cache_clear( LLAMA_API DEPRECATED(void llama_kv_cache_clear(
struct llama_context * ctx), struct llama_context * ctx),
"use llama_cache_clear instead"); "use llama_past_clear instead");
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// seq_id < 0 : match any sequence // seq_id < 0 : match any sequence
@ -595,7 +595,7 @@ extern "C" {
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
// Returns n_past (one more than the largest remaining pos in the seq_id) // Returns n_past (one more than the largest remaining pos in the seq_id)
// which is only meaningful to handle for partial removals. // which is only meaningful to handle for partial removals.
LLAMA_API llama_pos llama_cache_seq_rm( LLAMA_API llama_pos llama_past_seq_rm(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
@ -605,7 +605,7 @@ extern "C" {
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1), llama_pos p1),
"use llama_cache_seq_rm instead, and handle its return value for partial removals"); "use llama_past_seq_rm instead, and handle its return value for partial removals");
// Copy all tokens that belong to the specified sequence to another sequence // Copy all tokens that belong to the specified sequence to another sequence
// Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence
@ -613,7 +613,7 @@ extern "C" {
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
// Returns n_past (one more than the largest remaining pos in the destination seq_id) // Returns n_past (one more than the largest remaining pos in the destination seq_id)
// which is only meaningful to handle when partially copying. // which is only meaningful to handle when partially copying.
LLAMA_API llama_pos llama_cache_seq_cp( LLAMA_API llama_pos llama_past_seq_cp(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id_src, llama_seq_id seq_id_src,
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
@ -625,16 +625,16 @@ extern "C" {
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
llama_pos p0, llama_pos p0,
llama_pos p1), llama_pos p1),
"use llama_cache_seq_cp instead, and handle its return value for partial copies"); "use llama_past_seq_cp instead, and handle its return value for partial copies");
// Removes all tokens that do not belong to the specified sequence // Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_cache_seq_keep( LLAMA_API void llama_past_seq_keep(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id), llama_seq_id seq_id),
"use llama_cache_seq_keep instead"); "use llama_past_seq_keep instead");
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly: // If the KV cache is RoPEd, the KV data is updated accordingly:
@ -642,7 +642,7 @@ extern "C" {
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_cache_seq_add( LLAMA_API void llama_past_seq_add(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
@ -654,7 +654,7 @@ extern "C" {
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
llama_pos delta), llama_pos delta),
"use llama_cache_seq_add instead"); "use llama_past_seq_add instead");
// Integer division of the positions by factor of `d > 1` // Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly: // If the KV cache is RoPEd, the KV data is updated accordingly:
@ -662,7 +662,7 @@ extern "C" {
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_cache_seq_div( LLAMA_API void llama_past_seq_div(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
@ -674,16 +674,16 @@ extern "C" {
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d), int d),
"use llama_cache_seq_div instead"); "use llama_past_seq_div instead");
// Returns the largest position present in the KV and/or RS cache for the specified sequence // Returns the largest position present in the KV and/or RS cache for the specified sequence
LLAMA_API llama_pos llama_cache_seq_pos_max( LLAMA_API llama_pos llama_past_seq_pos_max(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id), llama_seq_id seq_id),
"use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); "use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells");
// Defragment the KV cache // Defragment the KV cache
// This will be applied: // This will be applied: