speculative : experimenting with Qwen2.5

This commit is contained in:
Georgi Gerganov 2024-11-14 11:31:31 +02:00
parent 33bdee667e
commit 5e6dad9322
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -12,7 +12,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft { struct seq_draft {
@ -188,6 +188,8 @@ int main(int argc, char ** argv) {
// draft sequence data // draft sequence data
std::vector<seq_draft> drafts(n_seq_dft); std::vector<seq_draft> drafts(n_seq_dft);
params.sparams.top_k = std::max(10, params.sparams.top_k);
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
// allocate llama_sampler for each draft sequence // allocate llama_sampler for each draft sequence
drafts[s].smpl = common_sampler_init(model_dft, params.sparams); drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
@ -346,6 +348,7 @@ int main(int argc, char ** argv) {
std::vector<float> probs(dist_tgt.size); std::vector<float> probs(dist_tgt.size);
for (size_t i = 0; i < dist_tgt.size; ++i) { for (size_t i = 0; i < dist_tgt.size; ++i) {
probs[i] = dist_tgt.data[i].p; probs[i] = dist_tgt.data[i].p;
LOG_DBG(" - %d: %f\n", dist_tgt.data[i].id, dist_tgt.data[i].p);
} }
std::discrete_distribution<> dist(probs.begin(), probs.end()); std::discrete_distribution<> dist(probs.begin(), probs.end());
@ -449,10 +452,13 @@ int main(int argc, char ** argv) {
break; break;
} }
if (drafts[0].smpl) { // TODO: this needs better fix - we want the draft samplers to have different parameters from the target sampler
common_sampler_free(drafts[0].smpl); // so we should not copy the target sampler
} //if (drafts[0].smpl) {
drafts[0].smpl = common_sampler_clone(smpl); // common_sampler_free(drafts[0].smpl);
//}
//drafts[0].smpl = common_sampler_clone(smpl);
common_sampler_reset(drafts[0].smpl);
int n_seq_cur = 1; int n_seq_cur = 1;
int n_past_cur = n_past_dft; int n_past_cur = n_past_dft;
@ -540,6 +546,12 @@ int main(int argc, char ** argv) {
const int s = sa[is]; const int s = sa[is];
// only collect very high-confidence draft tokens
if (cur_p->data[is].p < 0.90) {
drafts[s].drafting = false;
continue;
}
common_sampler_accept(drafts[s].smpl, id, true); common_sampler_accept(drafts[s].smpl, id, true);
drafts[s].tokens.push_back(id); drafts[s].tokens.push_back(id);
@ -577,6 +589,12 @@ int main(int argc, char ** argv) {
} }
} }
// don't waste time on small batches
if (batch_tgt.n_tokens < 5) {
batch_tgt.n_tokens = 1;
drafts[0].tokens.resize(batch_tgt.n_tokens);
}
// evaluate the target model on the drafted tokens // evaluate the target model on the drafted tokens
{ {
llama_kv_cache_seq_keep(ctx_tgt, 0); llama_kv_cache_seq_keep(ctx_tgt, 0);