mirror of https://github.com/ggerganov/llama.cpp.git

speculative : bug fixes

commit 4e82b2ea3f (parent 0e89203b51)
@@ -37,8 +37,8 @@ int main(int argc, char ** argv) {
     const int n_seq_dft = params.n_parallel;
 
     // TODO: make this configurable
-    const float p_accept = 0.4f;
-    const float p_split = 0.3f;
+    const float p_accept = 0.80f;
+    const float p_split = 0.10f;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("speculative", "log"));
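This hunk raises the default p_accept threshold from 0.4 to 0.80 and lowers p_split from 0.3 to 0.10. Elsewhere in this diff (the -277 hunk), p_accept is compared against the top draft candidate's probability to decide when to stop drafting a sequence; p_split is presumably the threshold at which a strong runner-up candidate forks an extra draft sequence, though that use is not shown here. A rough, self-contained sketch of how such thresholds are typically applied (toy names and control flow, not the llama.cpp API):

    // Illustrative only: toy use of drafting thresholds; names and control flow
    // are simplified and are not taken from llama.cpp.
    #include <cstdio>
    #include <vector>

    struct candidate { int id; float p; }; // token id + draft-model probability

    int main() {
        const float p_accept = 0.80f; // stop drafting when the top candidate is less confident than this
        const float p_split  = 0.10f; // (assumed role) fork an extra draft sequence when a runner-up exceeds this

        // pretend these came from the draft model, already sorted by probability
        const std::vector<candidate> cur_p = { {42, 0.55f}, {7, 0.30f}, {13, 0.15f} };

        if (cur_p[0].p < p_accept) {
            std::printf("stop drafting: top probability %.2f < %.2f\n", cur_p[0].p, p_accept);
        } else if (cur_p[1].p > p_split) {
            std::printf("split: runner-up %.2f > %.2f, draft both candidates in parallel\n", cur_p[1].p, p_split);
        } else {
            std::printf("keep drafting with token %d\n", cur_p[0].id);
        }
        return 0;
    }

With the new defaults, drafting stops earlier on uncertain tokens, and (under the assumed role of p_split) branching into parallel draft sequences becomes easier.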
@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     std::vector<seq_draft> drafts(n_seq_dft);
 
     params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
+    params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp);
 
     for (int s = 0; s < n_seq_dft; ++s) {
         drafts[s].ctx_sampling = llama_sampling_init(params);
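The draft samplers previously ran at a fixed temperature of 1.0; with this change they inherit the target sampling temperature, clamped to at least 0.01, presumably so that a temperature of zero (greedy sampling) never reaches the temperature-scaled softmax. A minimal standalone illustration of why such a clamp matters (toy code, not taken from llama.cpp):

    // Illustrative only: temperature-scaled softmax on a few made-up logits.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    static std::vector<float> softmax_t(std::vector<float> logits, float temp) {
        temp = std::max(0.01f, temp); // same clamp as in the hunk above; temp == 0 would divide by zero
        const float max_l = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (auto & l : logits) { l = std::exp((l - max_l) / temp); sum += l; }
        for (auto & l : logits) { l /= sum; }
        return logits;
    }

    int main() {
        const std::vector<float> logits = { 2.0f, 1.0f, 0.5f };
        for (const float temp : { 1.0f, 0.5f, 0.0f }) {
            const auto p = softmax_t(logits, temp);
            std::printf("temp = %.2f -> p = { %.3f, %.3f, %.3f }\n", temp, p[0], p[1], p[2]);
        }
        return 0;
    }

As the temperature approaches the clamp, the distribution collapses toward the top token, which is the closest well-defined stand-in for greedy sampling.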
@@ -156,7 +156,7 @@ int main(int argc, char ** argv) {
 
             llama_sampling_accept(ctx_sampling, ctx_tgt, id);
 
-            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
+            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
 
             const std::string token_str = llama_token_to_piece(ctx_tgt, id);
 
@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
 
         // TODO: simplify
         {
-            LOG("keeping sequence %d\n", s_keep);
+            LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
             llama_kv_cache_seq_keep(ctx_dft, s_keep);
             llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
@@ -277,7 +277,7 @@ int main(int argc, char ** argv) {
                 }
 
                 if (cur_p[0].p < p_accept) {
-                    LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
+                    LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
                     drafts[s].drafting = false;
                     continue;
                 }
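The stopping condition compares cur_p[0].p against p_accept, but the old log text still printed "2*%.3f" with cur_p[1].p, apparently left over from an earlier heuristic; the message now reports the threshold that is actually tested.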
@@ -337,16 +337,14 @@ int main(int argc, char ** argv) {
 
                     llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
 
-                    // no need to evaluate the last drafted token, since we won't use the result
-                    if (batch_tgt.n_tokens > n_draft) {
-                        drafts[s].drafting = false;
-                        continue;
-                    }
-
                     // add the token to the batch for batched decoding with the draft model
                     drafts[s].i_batch_dft = batch_dft.n_tokens;
 
                     llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
+
+                    if (batch_tgt.n_tokens > n_draft) {
+                        drafts[s].drafting = false;
+                    }
                 }
             }
 
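Before this change the loop bailed out with "continue" once the target batch was full, so the last drafted token was never added to batch_dft. Now every drafted token is also queued for the draft model, and drafting is simply switched off once batch_tgt exceeds the n_draft budget, so the draft model also evaluates the final drafted token. A compact standalone sketch of the reordered control flow (made-up container types, not the llama.cpp batch API):

    // Illustrative only: the stop condition now runs after the bookkeeping,
    // instead of skipping it with "continue".
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_draft = 4;      // drafting budget for the target batch
        std::vector<int> batch_tgt; // tokens queued for verification by the target model
        std::vector<int> batch_dft; // tokens queued for the draft model
        bool drafting = true;

        for (int id = 100; drafting; ++id) {
            batch_tgt.push_back(id); // always queue the token for target verification

            batch_dft.push_back(id); // new behavior: also always queue it for the draft model

            if ((int) batch_tgt.size() > n_draft) {
                drafting = false;    // stop drafting without skipping the bookkeeping above
            }
        }

        std::printf("queued %zu tokens for the target and %zu for the draft model\n",
                    batch_tgt.size(), batch_dft.size());
        return 0;
    }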
@@ -365,11 +363,6 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // account for the last drafted token that we didn't evaluate
-        if (batch_tgt.n_tokens > n_draft) {
-            ++n_drafted;
-        }
-
         // evaluate the target model on the drafted tokens
         {
             llama_kv_cache_seq_keep(ctx_tgt, 0);
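With the last drafted token now handled inside the drafting loop (previous hunk), the special-case "++n_drafted" accounting for a token that was never evaluated is no longer needed and is removed here.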