llama : update llama_kv_self API

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-14 16:47:34 +02:00
parent 4ef4c96100
commit e7f2dc8bc4
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
30 changed files with 386 additions and 203 deletions

View File

@ -909,9 +909,7 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams; return iparams;
} }
llama_kv_cache * kv = llama_get_kv_cache(lctx); if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
if (params.ctx_shift && !llama_kv_cache_can_shift(kv)) {
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__); LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
params.ctx_shift = false; params.ctx_shift = false;
} }
@ -1016,7 +1014,7 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_decoder(model)) { if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
} }
llama_kv_cache_clear(kv); llama_kv_self_clear(lctx);
llama_synchronize(lctx); llama_synchronize(lctx);
llama_perf_context_reset(lctx); llama_perf_context_reset(lctx);
} }

View File

@ -171,10 +171,8 @@ llama_tokens common_speculative_gen_draft(
llama_tokens result; llama_tokens result;
result.reserve(params.n_draft); result.reserve(params.n_draft);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
if (reuse_n == 0) { if (reuse_n == 0) {
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
prompt.clear(); prompt.clear();
} else { } else {
@ -193,14 +191,14 @@ llama_tokens common_speculative_gen_draft(
} }
if (reuse_i > 0) { if (reuse_i > 0) {
llama_kv_cache_seq_rm (kv, 0, 0, reuse_i); llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
llama_kv_cache_seq_add(kv, 0, reuse_i, -1, -reuse_i); llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
prompt.erase(prompt.begin(), prompt.begin() + reuse_i); prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
} }
if (reuse_n < (int) prompt.size()) { if (reuse_n < (int) prompt.size()) {
llama_kv_cache_seq_rm (kv, 0, reuse_n, -1); llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
prompt.erase(prompt.begin() + reuse_n, prompt.end()); prompt.erase(prompt.begin() + reuse_n, prompt.end());
} }

View File

@ -57,8 +57,6 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const int32_t n_kv_max = llama_n_ctx(ctx); const int32_t n_kv_max = llama_n_ctx(ctx);
llama_batch batch = llama_batch_init(n_kv_max, 0, 1); llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
@ -134,7 +132,7 @@ int main(int argc, char ** argv) {
const auto t_pp_start = ggml_time_us(); const auto t_pp_start = ggml_time_us();
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
if (!decode_helper(ctx, batch, ctx_params.n_batch)) { if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_ERR("%s: llama_decode() failed\n", __func__); LOG_ERR("%s: llama_decode() failed\n", __func__);
@ -143,7 +141,7 @@ int main(int argc, char ** argv) {
if (is_pp_shared) { if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) { for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(kv, 0, i, -1, -1); llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
} }
} }

View File

@ -111,7 +111,7 @@ if llama_decode(context, batch) != 0 {
} }
for i in 1 ..< n_parallel { for i in 1 ..< n_parallel {
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
} }
if n_parallel > 1 { if n_parallel > 1 {

View File

@ -342,8 +342,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
} }
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) { static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_cache * kv = llama_get_kv_cache(ctx); llama_kv_self_clear(ctx);
llama_kv_cache_clear(kv);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return false; return false;

View File

@ -34,11 +34,10 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
const llama_model * model = llama_get_model(ctx); const struct llama_model * model = llama_get_model(ctx);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
// run model // run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

View File

@ -13,8 +13,6 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
for (uint64_t i = 0; i < sentences.size(); i++) { for (uint64_t i = 0; i < sentences.size(); i++) {
@ -47,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
} }
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
llama_set_embeddings(ctx, true); llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false); llama_set_causal_attn(ctx, false);
@ -102,11 +100,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
llama_token eos_token = llama_vocab_eos(vocab); llama_token eos_token = llama_vocab_eos(vocab);
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
llama_set_embeddings(ctx, false); llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true); llama_set_causal_attn(ctx, true);

View File

@ -431,8 +431,6 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const bool add_bos = llama_vocab_get_add_bos(vocab); const bool add_bos = llama_vocab_get_add_bos(vocab);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
@ -499,7 +497,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1); llama_batch batch = llama_batch_init(n_batch, 0, 1);

View File

@ -139,8 +139,6 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx_train = llama_model_n_ctx_train(model);
@ -334,8 +332,8 @@ int main(int argc, char ** argv) {
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard); n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (kv, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); llama_kv_self_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_add(kv, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); llama_kv_self_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard; n_past -= n_discard;

View File

@ -1575,11 +1575,9 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_kv_cache * kv = llama_get_kv_cache(ctx);
test t(inst, lmodel, ctx); test t(inst, lmodel, ctx);
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
// cool off before the test // cool off before the test
if (params.delay) { if (params.delay) {
@ -1619,7 +1617,7 @@ int main(int argc, char ** argv) {
} }
for (int i = 0; i < params.reps; i++) { for (int i = 0; i < params.reps; i++) {
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
uint64_t t_start = get_time_ns(); uint64_t t_start = get_time_ns();

View File

@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
} }
batch->logits[batch->n_tokens - 1] = true; batch->logits[batch->n_tokens - 1] = true;
llama_kv_cache_clear(context); llama_kv_self_clear(context);
const auto t_pp_start = ggml_time_us(); const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) { if (llama_decode(context, *batch) != 0) {
@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
LOGi("Benchmark text generation (tg)"); LOGi("Benchmark text generation (tg)");
llama_kv_cache_clear(context); llama_kv_self_clear(context);
const auto t_tg_start = ggml_time_us(); const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) { for (i = 0; i < tg; i++) {
@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
const auto t_tg_end = ggml_time_us(); const auto t_tg_end = ggml_time_us();
llama_kv_cache_clear(context); llama_kv_self_clear(context);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context)); llama_kv_self_clear(reinterpret_cast<llama_context *>(context));
} }

View File

@ -208,7 +208,7 @@ actor LlamaContext {
} }
batch.logits[Int(batch.n_tokens) - 1] = 1 // true batch.logits[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context) llama_kv_self_clear(context)
let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000; let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
@ -221,7 +221,7 @@ actor LlamaContext {
// bench text generation // bench text generation
llama_kv_cache_clear(context) llama_kv_self_clear(context)
let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000; let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
@ -240,7 +240,7 @@ actor LlamaContext {
let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000; let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;
llama_kv_cache_clear(context) llama_kv_self_clear(context)
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@ -290,7 +290,7 @@ actor LlamaContext {
func clear() { func clear() {
tokens_list.removeAll() tokens_list.removeAll()
temporary_invalid_cchars.removeAll() temporary_invalid_cchars.removeAll()
llama_kv_cache_clear(context) llama_kv_self_clear(context)
} }
private func tokenize(text: String, add_bos: Bool) -> [llama_token] { private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

View File

@ -60,7 +60,6 @@ int main(int argc, char ** argv) {
llama_model * model = llama_init.model.get(); llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get(); llama_context * ctx = llama_init.context.get();
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
@ -96,7 +95,7 @@ int main(int argc, char ** argv) {
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
for (int s = 1; s < W + G + 1; ++s) { for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(kv, 0, s, -1, -1); llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
} }
const auto t_enc_end = ggml_time_us(); const auto t_enc_end = ggml_time_us();
@ -438,17 +437,17 @@ int main(int argc, char ** argv) {
// KV cache management // KV cache management
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
llama_kv_cache_seq_rm(kv, -1, n_past, -1); llama_kv_self_seq_rm(ctx, -1, n_past, -1);
if (seq_id_best != 0) { if (seq_id_best != 0) {
// if a verification token matched, we keep the best sequence and remove the rest // if a verification token matched, we keep the best sequence and remove the rest
// this leads to some KV cache fragmentation // this leads to some KV cache fragmentation
llama_kv_cache_seq_keep(kv, seq_id_best); llama_kv_self_seq_keep(ctx, seq_id_best);
llama_kv_cache_seq_cp (kv, seq_id_best, 0, -1, -1); llama_kv_self_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_cache_seq_rm (kv, seq_id_best, -1, -1); llama_kv_self_seq_rm (ctx, seq_id_best, -1, -1);
for (int s = 1; s < W + G + 1; ++s) { for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(kv, 0, s, -1, -1); llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
} }
} }
} }

View File

@ -35,7 +35,6 @@ int main(int argc, char ** argv){
llama_model * model = llama_init.model.get(); llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get(); llama_context * ctx = llama_init.context.get();
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
@ -193,7 +192,7 @@ int main(int argc, char ** argv){
// KV cache management // KV cache management
// clean the cache of draft tokens that weren't accepted // clean the cache of draft tokens that weren't accepted
llama_kv_cache_seq_rm(kv, 0, n_past, -1); llama_kv_self_seq_rm(ctx, 0, n_past, -1);
common_batch_clear(batch_tgt); common_batch_clear(batch_tgt);
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

View File

@ -164,8 +164,6 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@ -328,7 +326,7 @@ int main(int argc, char ** argv) {
} }
// remove any "future" tokens that we might have inherited from the previous session // remove any "future" tokens that we might have inherited from the previous session
llama_kv_cache_seq_rm(kv, -1, n_matching_session_tokens, -1); llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1);
} }
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n", LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@ -569,8 +567,8 @@ int main(int argc, char ** argv) {
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard); n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (kv, 0, params.n_keep , params.n_keep + n_discard); llama_kv_self_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(kv, 0, params.n_keep + n_discard, n_past, -n_discard); llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard; n_past -= n_discard;
@ -593,9 +591,9 @@ int main(int argc, char ** argv) {
LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
llama_kv_cache_seq_add(kv, 0, ga_i, n_past, ib*bd); llama_kv_self_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_cache_seq_div(kv, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); llama_kv_self_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_cache_seq_add(kv, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
n_past -= bd; n_past -= bd;

View File

@ -134,7 +134,6 @@ int main(int argc, char ** argv) {
llama_model * model = llama_init.model.get(); llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get(); llama_context * ctx = llama_init.context.get();
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
@ -202,7 +201,7 @@ int main(int argc, char ** argv) {
// assign the system KV cache to all parallel sequences // assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= n_clients; ++i) { for (int32_t i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_cp(kv, 0, i, -1, -1); llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
} }
LOG_INF("\n"); LOG_INF("\n");
@ -234,9 +233,9 @@ int main(int argc, char ** argv) {
if (batch.n_tokens == 0) { if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache // all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) { for (int i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_rm(kv, i, -1, -1); llama_kv_self_seq_rm(ctx, i, -1, -1);
// but keep the system prompt // but keep the system prompt
llama_kv_cache_seq_cp(kv, 0, i, -1, -1); llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
} }
LOG_INF("%s: clearing the KV cache\n", __func__); LOG_INF("%s: clearing the KV cache\n", __func__);
@ -372,8 +371,8 @@ int main(int argc, char ** argv) {
} }
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(kv, client.id + 1, -1, -1); llama_kv_self_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(kv, 0, client.id + 1, -1, -1); llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1);
const auto t_main_end = ggml_time_us(); const auto t_main_end = ggml_time_us();

View File

@ -86,8 +86,6 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_kv_cache * kv = llama_get_kv_cache(ctx);
auto sparams = llama_sampler_chain_default_params(); auto sparams = llama_sampler_chain_default_params();
llama_sampler * smpl = llama_sampler_chain_init(sparams); llama_sampler * smpl = llama_sampler_chain_init(sparams);
@ -134,11 +132,11 @@ int main(int argc, char ** argv) {
const int ib = i/n_batch - 1; const int ib = i/n_batch - 1;
const int bd = n_batch_grp*(n_grp - 1); const int bd = n_batch_grp*(n_grp - 1);
llama_kv_cache_seq_add(kv, 0, n_past - n_batch, n_past, ib*bd); llama_kv_self_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_cache_seq_div(kv, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_update_kv_cache (ctx, kv); llama_kv_self_update (ctx);
n_past = llama_kv_cache_seq_pos_max(kv, 0) + 1; n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
} }
common_batch_clear(batch); common_batch_clear(batch);
@ -168,12 +166,12 @@ int main(int argc, char ** argv) {
LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard); LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
llama_kv_cache_seq_rm (kv, 0, n_keep , n_keep + n_discard); llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(kv, 0, n_keep + n_discard, n_ctx, -n_discard); llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag (kv); //llama_kv_self_defrag (ctx);
llama_update_kv_cache (ctx, kv); llama_kv_self_update (ctx);
n_past = llama_kv_cache_seq_pos_max(kv, 0) + 1; n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
common_batch_clear(batch); common_batch_clear(batch);
@ -199,12 +197,12 @@ int main(int argc, char ** argv) {
if (n_discard > 0) { if (n_discard > 0) {
LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
llama_kv_cache_seq_rm (kv, 0, n_keep , n_keep + n_discard); llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(kv, 0, n_keep + n_discard, n_ctx, -n_discard); llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag (kv); //llama_kv_self_defrag (ctx);
llama_update_kv_cache (ctx, kv); llama_kv_self_update (ctx);
n_past = llama_kv_cache_seq_pos_max(kv, 0) + 1; n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
} }
} }

View File

@ -299,8 +299,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const bool add_bos = llama_vocab_get_add_bos(vocab); const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
@ -362,7 +360,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1); llama_batch batch = llama_batch_init(n_batch, 0, 1);
@ -452,8 +450,6 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const bool add_bos = llama_vocab_get_add_bos(vocab); const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
@ -550,7 +546,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
@ -745,8 +741,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
// Calculates hellaswag score (acc_norm) from prompt // Calculates hellaswag score (acc_norm) from prompt
// //
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
@ -929,7 +923,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
return; return;
} }
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
// decode all tasks [i0, i1) // decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1090,8 +1084,6 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
constexpr int k_min_trailing_ctx = 3; constexpr int k_min_trailing_ctx = 3;
auto data = load_winogrande_from_csv(params.prompt); auto data = load_winogrande_from_csv(params.prompt);
@ -1210,7 +1202,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
return; return;
} }
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
// decode all tasks [i0, i1) // decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1396,8 +1388,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
std::istringstream strstream(params.prompt); std::istringstream strstream(params.prompt);
uint32_t n_task; uint32_t n_task;
strstream.read((char *)&n_task, sizeof(n_task)); strstream.read((char *)&n_task, sizeof(n_task));
@ -1584,7 +1574,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
return; return;
} }
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
// decode all tasks [i0, i1) // decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1681,8 +1671,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
if (params.logits_file.empty()) { if (params.logits_file.empty()) {
LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__); LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
return; return;
@ -1776,7 +1764,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
} }
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1); llama_batch batch = llama_batch_init(n_batch, 0, 1);

View File

@ -82,10 +82,8 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
} }
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
llama_kv_cache * kv = llama_get_kv_cache(ctx);
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
// run model // run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

View File

@ -743,10 +743,8 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
// Check if we have enough space in the context to evaluate this batch // Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
llama_kv_cache * kv = llama_get_kv_cache(ctx.get());
const int n_ctx = llama_n_ctx(ctx.get()); const int n_ctx = llama_n_ctx(ctx.get());
const int n_ctx_used = llama_kv_cache_used_cells(kv); const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
if (n_ctx_used + batch.n_tokens > n_ctx) { if (n_ctx_used + batch.n_tokens > n_ctx) {
printf("\033[0m\n"); printf("\033[0m\n");
printe("context size exceeded\n"); printe("context size exceeded\n");

View File

@ -156,8 +156,6 @@ int main(int argc, char ** argv) {
// make new context // make new context
llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params)); llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
llama_kv_cache * kv3 = llama_get_kv_cache(ctx3);
llama_sampler * smpl3 = llama_sampler_chain_init(sparams); llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed)); llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed));
@ -198,7 +196,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
// erase whole kv // erase whole kv
llama_kv_cache_clear(kv3); llama_kv_self_clear(ctx3);
fprintf(stderr, "%s : kv cache cleared\n", __func__); fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1 // restore kv into seq 1

View File

@ -1661,7 +1661,6 @@ struct server_context {
llama_model * model = nullptr; llama_model * model = nullptr;
llama_context * ctx = nullptr; llama_context * ctx = nullptr;
llama_kv_cache * kv = nullptr;
const llama_vocab * vocab = nullptr; const llama_vocab * vocab = nullptr;
@ -1722,8 +1721,6 @@ struct server_context {
return false; return false;
} }
kv = llama_get_kv_cache(ctx);
vocab = llama_model_get_vocab(model); vocab = llama_model_get_vocab(model);
n_ctx = llama_n_ctx(ctx); n_ctx = llama_n_ctx(ctx);
@ -1961,7 +1958,7 @@ struct server_context {
SRV_DBG("%s", "clearing KV cache\n"); SRV_DBG("%s", "clearing KV cache\n");
// clear the entire KV cache // clear the entire KV cache
llama_kv_cache_clear(kv); llama_kv_self_clear(ctx);
clean_kv_cache = false; clean_kv_cache = false;
} }
@ -2503,8 +2500,8 @@ struct server_context {
res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
res->t_start = metrics.t_start; res->t_start = metrics.t_start;
res->kv_cache_tokens_count = llama_kv_cache_n_tokens(kv); res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx);
res->kv_cache_used_cells = llama_kv_cache_used_cells(kv); res->kv_cache_used_cells = llama_kv_self_used_cells(ctx);
res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
res->t_prompt_processing_total = metrics.t_prompt_processing_total; res->t_prompt_processing_total = metrics.t_prompt_processing_total;
@ -2620,7 +2617,7 @@ struct server_context {
// Erase token cache // Erase token cache
const size_t n_erased = slot->cache_tokens.size(); const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(kv, slot->id, -1, -1); llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear(); slot->cache_tokens.clear();
auto res = std::make_unique<server_task_result_slot_erase>(); auto res = std::make_unique<server_task_result_slot_erase>();
@ -2688,8 +2685,8 @@ struct server_context {
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_kv_cache_seq_rm (kv, slot.id, n_keep , n_keep + n_discard); llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(kv, slot.id, n_keep + n_discard, slot.n_past, -n_discard); llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
if (slot.params.cache_prompt) { if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@ -2876,8 +2873,8 @@ struct server_context {
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
llama_kv_cache_seq_rm (kv, slot.id, head_p, head_c); llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c);
llama_kv_cache_seq_add(kv, slot.id, head_c, -1, kv_shift); llama_kv_self_seq_add(ctx, slot.id, head_c, -1, kv_shift);
for (size_t i = 0; i < n_match; i++) { for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
@ -2915,9 +2912,9 @@ struct server_context {
} }
// keep only the common part // keep only the common part
if (!llama_kv_cache_seq_rm(kv, slot.id, slot.n_past, -1)) { if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model) // could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(kv, slot.id, -1, -1); llama_kv_self_seq_rm(ctx, slot.id, -1, -1);
// there is no common part left // there is no common part left
slot.n_past = 0; slot.n_past = 0;
@ -3157,7 +3154,7 @@ struct server_context {
slot.cache_tokens.push_back(id); slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(kv, slot.id, slot.n_past, -1); llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) { for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result; completion_token_output result;

View File

@ -88,8 +88,6 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_kv_cache * kv = llama_get_kv_cache(ctx);
// initialize the sampler // initialize the sampler
llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params()); llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1)); llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
@ -103,7 +101,7 @@ int main(int argc, char ** argv) {
// tokenize the prompt // tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
std::vector<llama_token> prompt_tokens(n_prompt_tokens); std::vector<llama_token> prompt_tokens(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_kv_cache_used_cells(kv) == 0, true) < 0) { if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_kv_self_used_cells(ctx) == 0, true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n"); GGML_ABORT("failed to tokenize the prompt\n");
} }
@ -113,7 +111,7 @@ int main(int argc, char ** argv) {
while (true) { while (true) {
// check if we have enough space in the context to evaluate this batch // check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(ctx); int n_ctx = llama_n_ctx(ctx);
int n_ctx_used = llama_kv_cache_used_cells(kv); int n_ctx_used = llama_kv_self_used_cells(ctx);
if (n_ctx_used + batch.n_tokens > n_ctx) { if (n_ctx_used + batch.n_tokens > n_ctx) {
printf("\033[0m\n"); printf("\033[0m\n");
fprintf(stderr, "context size exceeded\n"); fprintf(stderr, "context size exceeded\n");

View File

@ -45,8 +45,6 @@ int main(int argc, char ** argv) {
model_tgt = llama_init_tgt.model.get(); model_tgt = llama_init_tgt.model.get();
ctx_tgt = llama_init_tgt.context.get(); ctx_tgt = llama_init_tgt.context.get();
llama_kv_cache * kv = llama_get_kv_cache(ctx_tgt);
const llama_vocab * vocab = llama_model_get_vocab(model_tgt); const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
// load the draft model // load the draft model
@ -219,7 +217,7 @@ int main(int argc, char ** argv) {
{ {
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
llama_kv_cache_seq_rm(kv, 0, n_past, -1); llama_kv_self_seq_rm(ctx_tgt, 0, n_past, -1);
} }
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {

View File

@ -90,9 +90,6 @@ int main(int argc, char ** argv) {
model_dft = llama_init_dft.model.get(); model_dft = llama_init_dft.model.get();
ctx_dft = llama_init_dft.context.get(); ctx_dft = llama_init_dft.context.get();
llama_kv_cache * kv_tgt = llama_get_kv_cache(ctx_tgt);
llama_kv_cache * kv_dft = llama_get_kv_cache(ctx_dft);
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
@ -423,14 +420,14 @@ int main(int argc, char ** argv) {
{ {
LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_keep(kv_dft, s_keep); llama_kv_self_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (kv_dft, s_keep, 0, -1, -1); llama_kv_self_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(kv_dft, 0); llama_kv_self_seq_keep(ctx_dft, 0);
llama_kv_cache_seq_rm (kv_tgt, s_keep, n_past_tgt, -1); llama_kv_self_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_keep(kv_tgt, s_keep); llama_kv_self_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (kv_tgt, s_keep, 0, -1, -1); llama_kv_self_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(kv_tgt, 0); llama_kv_self_seq_keep(ctx_tgt, 0);
} }
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
@ -447,8 +444,8 @@ int main(int argc, char ** argv) {
common_batch_clear(batch_dft); common_batch_clear(batch_dft);
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(kv_dft, 0, n_past_dft, -1); llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(kv_dft, batch_dft).c_str()); // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft); llama_decode(ctx_dft, batch_dft);
++n_past_dft; ++n_past_dft;
@ -506,8 +503,8 @@ int main(int argc, char ** argv) {
if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) { if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur); LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(kv_dft, n_seq_cur, -1, -1); llama_kv_self_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(kv_dft, s, n_seq_cur, -1, -1); llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch // all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) { for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@ -588,9 +585,9 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens // evaluate the target model on the drafted tokens
{ {
llama_kv_cache_seq_keep(kv_tgt, 0); llama_kv_self_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) { for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(kv_tgt, 0, s, -1, -1); llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
} }
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());

View File

@ -469,7 +469,7 @@ extern "C" {
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const? LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
LLAMA_API struct llama_kv_cache * llama_get_kv_cache( struct llama_context * ctx); LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
@ -640,28 +640,28 @@ extern "C" {
// Returns the number of tokens in the KV cache (slow, use only for debug) // Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_kv_cache_n_tokens(const struct llama_kv_cache * kv); LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx), DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
"use llama_kv_cache_n_tokens instead"); "use llama_kv_self_n_tokens instead");
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_kv_cache_used_cells(const struct llama_kv_cache * kv); LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx), DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
"use llama_kv_cache_used_cells instead"); "use llama_kv_self_used_cells instead");
// Clear the KV cache - both cell info is erased and KV data is zeroed // Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_cache_clear( LLAMA_API void llama_kv_self_clear(
struct llama_kv_cache * kv); struct llama_context * ctx);
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// seq_id < 0 : match any sequence // seq_id < 0 : match any sequence
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API bool llama_kv_cache_seq_rm( LLAMA_API bool llama_kv_self_seq_rm(
struct llama_kv_cache * kv, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1); llama_pos p1);
@ -670,26 +670,26 @@ extern "C" {
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_cp( LLAMA_API void llama_kv_self_seq_cp(
struct llama_kv_cache * kv, struct llama_context * ctx,
llama_seq_id seq_id_src, llama_seq_id seq_id_src,
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
llama_pos p0, llama_pos p0,
llama_pos p1); llama_pos p1);
// Removes all tokens that do not belong to the specified sequence // Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_cache_seq_keep( LLAMA_API void llama_kv_self_seq_keep(
struct llama_kv_cache * kv, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly: // If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode() // - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_self_update()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_add( LLAMA_API void llama_kv_self_seq_add(
struct llama_kv_cache * kv, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
@ -698,32 +698,87 @@ extern "C" {
// Integer division of the positions by factor of `d > 1` // Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly: // If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode() // - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_self_update()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_div( LLAMA_API void llama_kv_self_seq_div(
struct llama_kv_cache * kv, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d); int d);
// Returns the largest position present in the KV cache for the specified sequence // Returns the largest position present in the KV cache for the specified sequence
LLAMA_API llama_pos llama_kv_cache_seq_pos_max( LLAMA_API llama_pos llama_kv_self_seq_pos_max(
struct llama_kv_cache * kv, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
// Defragment the KV cache // Defragment the KV cache
// This will be applied: // This will be applied:
// - lazily on next llama_decode() // - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_self_update()
LLAMA_API void llama_kv_cache_defrag(struct llama_kv_cache * kv); LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
// Check if the context supports KV cache shifting // Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_cache_can_shift(const struct llama_kv_cache * kv); LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.) // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_update_kv_cache(struct llama_context * ctx, struct llama_kv_cache * kv); LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
DEPRECATED(LLAMA_API void llama_kv_cache_clear(
struct llama_context * ctx),
"use llama_kv_self_clear instead");
DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1),
"use llama_kv_self_seq_rm instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1),
"use llama_kv_self_seq_cp instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_kv_self_seq_keep instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta),
"use llama_kv_self_seq_add instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d),
"use llama_kv_self_seq_div instead");
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_kv_self_seq_pos_max instead");
DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
"use llama_kv_self_defrag instead");
DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
"use llama_kv_self_can_shift instead");
DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
"use llama_kv_self_update instead");
// //
// State / sessions // State / sessions

View File

@ -606,7 +606,7 @@ const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->model; return &ctx->model;
} }
llama_kv_cache * llama_get_kv_cache(llama_context * ctx) { llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
return &ctx->kv_self; return &ctx->kv_self;
} }
@ -1147,14 +1147,14 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
data_ctx.write_embeddings(ctx); data_ctx.write_embeddings(ctx);
llama_kv_cache::io io = { llama_kv_cache::io io = {
/* .write =*/ [&](const void * src, size_t size) { /* .write = */ [&](const void * src, size_t size) {
data_ctx.write(src, size); data_ctx.write(src, size);
}, },
/* .write_tensor_data =*/ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) { /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
data_ctx.write_tensor_data(tensor, offset, size); data_ctx.write_tensor_data(tensor, offset, size);
}, },
/* .read =*/ nullptr, /* .read = */ nullptr,
/* .read_to =*/ nullptr, /* .read_to = */ nullptr,
}; };
ctx->kv_self.state_write(io, ctx->model.hparams); ctx->kv_self.state_write(io, ctx->model.hparams);
@ -1195,12 +1195,12 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
data_ctx.read_embeddings(ctx); data_ctx.read_embeddings(ctx);
llama_kv_cache::io io = { llama_kv_cache::io io = {
/* .write =*/ nullptr, /* .write = */ nullptr,
/* .write_tensor_data =*/ nullptr, /* .write_tensor_data = */ nullptr,
/* .read =*/ [&](size_t size) { /* .read = */ [&](size_t size) {
return data_ctx.read(size); return data_ctx.read(size);
}, },
/* .read_to =*/ [&](void * dst, size_t size) { /* .read_to = */ [&](void * dst, size_t size) {
data_ctx.read_to(dst, size); data_ctx.read_to(dst, size);
}, },
}; };
@ -1302,14 +1302,14 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
llama_synchronize(ctx); llama_synchronize(ctx);
llama_kv_cache::io io = { llama_kv_cache::io io = {
/* .write =*/ [&](const void * src, size_t size) { /* .write = */ [&](const void * src, size_t size) {
data_ctx.write(src, size); data_ctx.write(src, size);
}, },
/* .write_tensor_data =*/ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) { /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
data_ctx.write_tensor_data(tensor, offset, size); data_ctx.write_tensor_data(tensor, offset, size);
}, },
/* .read =*/ nullptr, /* .read = */ nullptr,
/* .read_to =*/ nullptr, /* .read_to = */ nullptr,
}; };
ctx->kv_self.state_write(io, ctx->model.hparams, seq_id); ctx->kv_self.state_write(io, ctx->model.hparams, seq_id);
@ -1336,12 +1336,12 @@ static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llam
llama_synchronize(ctx); llama_synchronize(ctx);
llama_kv_cache::io io = { llama_kv_cache::io io = {
/* .write =*/ nullptr, /* .write = */ nullptr,
/* .write_tensor_data =*/ nullptr, /* .write_tensor_data = */ nullptr,
/* .read =*/ [&](size_t size) { /* .read = */ [&](size_t size) {
return data_ctx.read(size); return data_ctx.read(size);
}, },
/* .read_to =*/ [&](void * dst, size_t size) { /* .read_to = */ [&](void * dst, size_t size) {
data_ctx.read_to(dst, size); data_ctx.read_to(dst, size);
}, },
}; };

View File

@ -1072,7 +1072,17 @@ bool llama_kv_cache::state_read_data(const io & io, const llama_hparams & hparam
return true; return true;
} }
///////////// //
// interface implementation
//
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
return kv->n_tokens();
}
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
return kv->used;
}
void llama_kv_cache_clear(llama_kv_cache * kv) { void llama_kv_cache_clear(llama_kv_cache * kv) {
kv->clear(); kv->clear();
@ -1125,14 +1135,6 @@ void llama_kv_cache_defrag(llama_kv_cache * kv) {
kv->defrag(); kv->defrag();
} }
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
return kv->n_tokens();
}
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
return kv->used;
}
bool llama_kv_cache_can_shift(const llama_kv_cache * kv) { bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
return kv->can_shift; return kv->can_shift;
} }

View File

@ -190,6 +190,48 @@ struct llama_kv_slot_restorer {
} }
}; };
// TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
void llama_kv_cache_clear(llama_kv_cache * kv);
bool llama_kv_cache_seq_rm(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_cp(
llama_kv_cache * kv,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_seq_add(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
void llama_kv_cache_seq_div(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_defrag(llama_kv_cache * kv);
bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
// //
// kv cache view // kv cache view
// //

View File

@ -8564,7 +8564,7 @@ static int llama_decode_impl(
// non-causal masks do not use the KV cache // non-causal masks do not use the KV cache
if (hparams.causal_attn) { if (hparams.causal_attn) {
llama_update_kv_cache(&lctx, &lctx.kv_self); // TODO: lctx->update_kv_cache() llama_kv_self_update(&lctx); // TODO: lctx->kv_self_update()
// if we have enough unused cells before the current head -> // if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it // better to start searching from the beginning of the cache, hoping to fill it
@ -9182,9 +9182,12 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
//LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0); //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
} }
static void llama_update_kv_cache_impl(llama_context & lctx, llama_kv_cache & kv) { // TODO: move to llama_context
static void llama_kv_self_update_impl(llama_context & lctx) {
bool need_reserve = false; bool need_reserve = false;
auto & kv = lctx.kv_self;
if (kv.has_shift) { if (kv.has_shift) {
if (!kv.can_shift) { if (!kv.can_shift) {
GGML_ABORT("The current context does not support K-shift"); GGML_ABORT("The current context does not support K-shift");
@ -9846,17 +9849,151 @@ void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view *
// deprecated // deprecated
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
return llama_kv_self_n_tokens(ctx);
}
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
return llama_kv_cache_n_tokens(&ctx->kv_self); return llama_kv_cache_n_tokens(&ctx->kv_self);
} }
// deprecated // deprecated
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
return llama_kv_self_used_cells(ctx);
}
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
return llama_kv_cache_used_cells(&ctx->kv_self); return llama_kv_cache_used_cells(&ctx->kv_self);
} }
// deprecated
void llama_kv_cache_clear(llama_context * ctx) {
llama_kv_self_clear(ctx);
}
void llama_kv_self_clear(llama_context * ctx) {
llama_kv_cache_clear(&ctx->kv_self);
}
// deprecated
bool llama_kv_cache_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
}
bool llama_kv_self_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1);
}
// deprecated
void llama_kv_cache_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_self_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
}
// deprecated
void llama_kv_cache_seq_keep(
llama_context * ctx,
llama_seq_id seq_id) {
return llama_kv_self_seq_keep(ctx, seq_id);
}
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id);
}
// deprecated
void llama_kv_cache_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
}
void llama_kv_self_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta);
}
// deprecated
void llama_kv_cache_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
}
void llama_kv_self_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d);
}
// deprecated
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_self_seq_pos_max(ctx, seq_id);
}
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id);
}
// deprecated
void llama_kv_cache_defrag(llama_context * ctx) {
return llama_kv_self_defrag(ctx);
}
void llama_kv_self_defrag(llama_context * ctx) {
return llama_kv_cache_defrag(&ctx->kv_self);
}
// deprecated
bool llama_kv_cache_can_shift(const llama_context * ctx) {
return llama_kv_self_can_shift(ctx);
}
bool llama_kv_self_can_shift(const llama_context * ctx) {
return llama_kv_cache_can_shift(&ctx->kv_self);
}
// deprecated
void llama_kv_cache_update(llama_context * ctx) {
llama_kv_self_update(ctx);
}
// TODO: move to llama-context // TODO: move to llama-context
void llama_update_kv_cache(llama_context * ctx, llama_kv_cache * kv) { void llama_kv_self_update(llama_context * ctx) {
llama_update_kv_cache_impl(*ctx, *kv); llama_kv_self_update_impl(*ctx);
} }
/// ///