From 4e82b2ea3fa6482915d147bc9f46e70b9ada7700 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 18 Oct 2023 18:49:40 +0300 Subject: [PATCH] speculative : bug fixes --- examples/speculative/speculative.cpp | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 53f42fad8..24f49012a 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -37,8 +37,8 @@ int main(int argc, char ** argv) { const int n_seq_dft = params.n_parallel; // TODO: make this configurable - const float p_accept = 0.4f; - const float p_split = 0.3f; + const float p_accept = 0.80f; + const float p_split = 0.10f; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); @@ -118,7 +118,7 @@ int main(int argc, char ** argv) { std::vector drafts(n_seq_dft); params.grammar.clear(); // the draft samplers will copy the target sampler's grammar - params.sampling_params.temp = 1.0f; // the draft samplers use default temperature + params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp); for (int s = 0; s < n_seq_dft; ++s) { drafts[s].ctx_sampling = llama_sampling_init(params); @@ -156,7 +156,7 @@ int main(int argc, char ** argv) { llama_sampling_accept(ctx_sampling, ctx_tgt, id); - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens)); + //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); const std::string token_str = llama_token_to_piece(ctx_tgt, id); @@ -202,7 +202,7 @@ int main(int argc, char ** argv) { // TODO: simplify { - LOG("keeping sequence %d\n", s_keep); + LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); llama_kv_cache_seq_keep(ctx_dft, s_keep); llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); @@ -277,7 +277,7 @@ int main(int argc, char ** argv) { } if (cur_p[0].p < p_accept) { - LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p); + LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept); drafts[s].drafting = false; continue; } @@ -337,16 +337,14 @@ int main(int argc, char ** argv) { llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); - // no need to evaluate the last drafted token, since we won't use the result - if (batch_tgt.n_tokens > n_draft) { - drafts[s].drafting = false; - continue; - } - // add the token to the batch for batched decoding with the draft model drafts[s].i_batch_dft = batch_dft.n_tokens; llama_batch_add(batch_dft, id, n_past_cur, { s }, true); + + if (batch_tgt.n_tokens > n_draft) { + drafts[s].drafting = false; + } } } @@ -365,11 +363,6 @@ int main(int argc, char ** argv) { } } - // account for the last drafted token that we didn't evaluate - if (batch_tgt.n_tokens > n_draft) { - ++n_drafted; - } - // evaluate the target model on the drafted tokens { llama_kv_cache_seq_keep(ctx_tgt, 0);