mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 11:23:56 +01:00
speculative : fix KV cache management
This commit is contained in:
parent
7c1bdd0e8a
commit
1f17ea631c
@ -172,6 +172,7 @@ int main(int argc, char ** argv) {
|
|||||||
LOG("out of drafted tokens\n");
|
LOG("out of drafted tokens\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
|
||||||
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
|
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
|
|
||||||
@ -217,6 +218,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// sample n_draft tokens from the draft model using greedy decoding
|
// sample n_draft tokens from the draft model using greedy decoding
|
||||||
int n_past_cur = n_past_dft;
|
int n_past_cur = n_past_dft;
|
||||||
|
|
||||||
for (int i = 0; i < n_draft; ++i) {
|
for (int i = 0; i < n_draft; ++i) {
|
||||||
float * logits = llama_get_logits(ctx_dft);
|
float * logits = llama_get_logits(ctx_dft);
|
||||||
|
|
||||||
@ -256,6 +258,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// evaluate the drafted token on the draft model
|
// evaluate the drafted token on the draft model
|
||||||
|
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
|
||||||
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
|
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
|
||||||
++n_past_cur;
|
++n_past_cur;
|
||||||
|
|
||||||
@ -265,6 +268,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// evaluate the target model on the drafted tokens
|
// evaluate the target model on the drafted tokens
|
||||||
|
llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
|
||||||
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
|
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
|
||||||
++n_past_tgt;
|
++n_past_tgt;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user