From 5e6dad9322b47bc072ae7f85513a07f23118a868 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 Nov 2024 11:31:31 +0200 Subject: [PATCH] speculative : experimenting with Qwen2.5 --- examples/speculative/speculative.cpp | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 6cafd8a83..769a41287 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -12,7 +12,7 @@ #include #include -#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 struct seq_draft { @@ -188,6 +188,8 @@ int main(int argc, char ** argv) { // draft sequence data std::vector drafts(n_seq_dft); + params.sparams.top_k = std::max(10, params.sparams.top_k); + for (int s = 0; s < n_seq_dft; ++s) { // allocate llama_sampler for each draft sequence drafts[s].smpl = common_sampler_init(model_dft, params.sparams); @@ -346,6 +348,7 @@ int main(int argc, char ** argv) { std::vector probs(dist_tgt.size); for (size_t i = 0; i < dist_tgt.size; ++i) { probs[i] = dist_tgt.data[i].p; + LOG_DBG(" - %d: %f\n", dist_tgt.data[i].id, dist_tgt.data[i].p); } std::discrete_distribution<> dist(probs.begin(), probs.end()); @@ -449,10 +452,13 @@ int main(int argc, char ** argv) { break; } - if (drafts[0].smpl) { - common_sampler_free(drafts[0].smpl); - } - drafts[0].smpl = common_sampler_clone(smpl); + // TODO: this needs better fix - we want the draft samplers to have different parameters from the target sampler + // so we should not copy the target sampler + //if (drafts[0].smpl) { + // common_sampler_free(drafts[0].smpl); + //} + //drafts[0].smpl = common_sampler_clone(smpl); + common_sampler_reset(drafts[0].smpl); int n_seq_cur = 1; int n_past_cur = n_past_dft; @@ -540,6 +546,12 @@ int main(int argc, char ** argv) { const int s = sa[is]; + // only collect very high-confidence draft tokens + if (cur_p->data[is].p < 0.90) { + drafts[s].drafting = false; + continue; + } + common_sampler_accept(drafts[s].smpl, id, true); drafts[s].tokens.push_back(id); @@ -577,6 +589,12 @@ int main(int argc, char ** argv) { } } + // don't waste time on small batches + if (batch_tgt.n_tokens < 5) { + batch_tgt.n_tokens = 1; + drafts[0].tokens.resize(batch_tgt.n_tokens); + } + // evaluate the target model on the drafted tokens { llama_kv_cache_seq_keep(ctx_tgt, 0);