mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
speculative : experimenting with Qwen2.5
This commit is contained in:
parent
33bdee667e
commit
5e6dad9322
@ -12,7 +12,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
|
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||||
|
|
||||||
struct seq_draft {
|
struct seq_draft {
|
||||||
@ -188,6 +188,8 @@ int main(int argc, char ** argv) {
|
|||||||
// draft sequence data
|
// draft sequence data
|
||||||
std::vector<seq_draft> drafts(n_seq_dft);
|
std::vector<seq_draft> drafts(n_seq_dft);
|
||||||
|
|
||||||
|
params.sparams.top_k = std::max(10, params.sparams.top_k);
|
||||||
|
|
||||||
for (int s = 0; s < n_seq_dft; ++s) {
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
// allocate llama_sampler for each draft sequence
|
// allocate llama_sampler for each draft sequence
|
||||||
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
|
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
|
||||||
@ -346,6 +348,7 @@ int main(int argc, char ** argv) {
|
|||||||
std::vector<float> probs(dist_tgt.size);
|
std::vector<float> probs(dist_tgt.size);
|
||||||
for (size_t i = 0; i < dist_tgt.size; ++i) {
|
for (size_t i = 0; i < dist_tgt.size; ++i) {
|
||||||
probs[i] = dist_tgt.data[i].p;
|
probs[i] = dist_tgt.data[i].p;
|
||||||
|
LOG_DBG(" - %d: %f\n", dist_tgt.data[i].id, dist_tgt.data[i].p);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||||
@ -449,10 +452,13 @@ int main(int argc, char ** argv) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (drafts[0].smpl) {
|
// TODO: this needs better fix - we want the draft samplers to have different parameters from the target sampler
|
||||||
common_sampler_free(drafts[0].smpl);
|
// so we should not copy the target sampler
|
||||||
}
|
//if (drafts[0].smpl) {
|
||||||
drafts[0].smpl = common_sampler_clone(smpl);
|
// common_sampler_free(drafts[0].smpl);
|
||||||
|
//}
|
||||||
|
//drafts[0].smpl = common_sampler_clone(smpl);
|
||||||
|
common_sampler_reset(drafts[0].smpl);
|
||||||
|
|
||||||
int n_seq_cur = 1;
|
int n_seq_cur = 1;
|
||||||
int n_past_cur = n_past_dft;
|
int n_past_cur = n_past_dft;
|
||||||
@ -540,6 +546,12 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const int s = sa[is];
|
const int s = sa[is];
|
||||||
|
|
||||||
|
// only collect very high-confidence draft tokens
|
||||||
|
if (cur_p->data[is].p < 0.90) {
|
||||||
|
drafts[s].drafting = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
common_sampler_accept(drafts[s].smpl, id, true);
|
common_sampler_accept(drafts[s].smpl, id, true);
|
||||||
|
|
||||||
drafts[s].tokens.push_back(id);
|
drafts[s].tokens.push_back(id);
|
||||||
@ -577,6 +589,12 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// don't waste time on small batches
|
||||||
|
if (batch_tgt.n_tokens < 5) {
|
||||||
|
batch_tgt.n_tokens = 1;
|
||||||
|
drafts[0].tokens.resize(batch_tgt.n_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
// evaluate the target model on the drafted tokens
|
// evaluate the target model on the drafted tokens
|
||||||
{
|
{
|
||||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user