mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 21:10:24 +01:00
speculative : experimenting with Qwen2.5
This commit is contained in:
parent
33bdee667e
commit
5e6dad9322
@ -12,7 +12,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
||||
struct seq_draft {
|
||||
@ -188,6 +188,8 @@ int main(int argc, char ** argv) {
|
||||
// draft sequence data
|
||||
std::vector<seq_draft> drafts(n_seq_dft);
|
||||
|
||||
params.sparams.top_k = std::max(10, params.sparams.top_k);
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
// allocate llama_sampler for each draft sequence
|
||||
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
|
||||
@ -346,6 +348,7 @@ int main(int argc, char ** argv) {
|
||||
std::vector<float> probs(dist_tgt.size);
|
||||
for (size_t i = 0; i < dist_tgt.size; ++i) {
|
||||
probs[i] = dist_tgt.data[i].p;
|
||||
LOG_DBG(" - %d: %f\n", dist_tgt.data[i].id, dist_tgt.data[i].p);
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
@ -449,10 +452,13 @@ int main(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (drafts[0].smpl) {
|
||||
common_sampler_free(drafts[0].smpl);
|
||||
}
|
||||
drafts[0].smpl = common_sampler_clone(smpl);
|
||||
// TODO: this needs better fix - we want the draft samplers to have different parameters from the target sampler
|
||||
// so we should not copy the target sampler
|
||||
//if (drafts[0].smpl) {
|
||||
// common_sampler_free(drafts[0].smpl);
|
||||
//}
|
||||
//drafts[0].smpl = common_sampler_clone(smpl);
|
||||
common_sampler_reset(drafts[0].smpl);
|
||||
|
||||
int n_seq_cur = 1;
|
||||
int n_past_cur = n_past_dft;
|
||||
@ -540,6 +546,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const int s = sa[is];
|
||||
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[is].p < 0.90) {
|
||||
drafts[s].drafting = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
common_sampler_accept(drafts[s].smpl, id, true);
|
||||
|
||||
drafts[s].tokens.push_back(id);
|
||||
@ -577,6 +589,12 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
// don't waste time on small batches
|
||||
if (batch_tgt.n_tokens < 5) {
|
||||
batch_tgt.n_tokens = 1;
|
||||
drafts[0].tokens.resize(batch_tgt.n_tokens);
|
||||
}
|
||||
|
||||
// evaluate the target model on the drafted tokens
|
||||
{
|
||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||
|
Loading…
x
Reference in New Issue
Block a user