mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
speculative : add heuristic algorithm (#3006)
* Add heuristic algo for speculative * Constrain minimum n_draft to 2 * speculative : improve heuristic impl * speculative : be more rewarding upon guessing max drafted tokens * speculative : fix typos --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
71ca2fad7d
commit
35f73049af
@ -82,7 +82,7 @@ int main(int argc, char ** argv) {
|
|||||||
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
|
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
|
||||||
|
|
||||||
// how many tokens to draft each time
|
// how many tokens to draft each time
|
||||||
const int n_draft = params.n_draft;
|
int n_draft = params.n_draft;
|
||||||
|
|
||||||
int n_predict = 0;
|
int n_predict = 0;
|
||||||
int n_drafted = 0;
|
int n_drafted = 0;
|
||||||
@ -131,6 +131,7 @@ int main(int argc, char ** argv) {
|
|||||||
LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
|
LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
|
||||||
|
|
||||||
int i_dft = 0;
|
int i_dft = 0;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
// sample from the target model
|
// sample from the target model
|
||||||
const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
|
const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
|
||||||
@ -174,6 +175,27 @@ int main(int argc, char ** argv) {
|
|||||||
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
|
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
|
|
||||||
|
// heuristic for n_draft
|
||||||
|
{
|
||||||
|
const int n_draft_cur = (int) drafted.size();
|
||||||
|
const bool all_accepted = i_dft == n_draft_cur;
|
||||||
|
|
||||||
|
LOG("n_draft = %d\n", n_draft);
|
||||||
|
LOG("n_draft_cur = %d\n", n_draft_cur);
|
||||||
|
LOG("i_dft = %d\n", i_dft);
|
||||||
|
LOG("all_accepted = %d\n", all_accepted);
|
||||||
|
|
||||||
|
if (all_accepted && n_draft == n_draft_cur) {
|
||||||
|
LOG(" - max drafted tokens accepted - n_draft += 8\n");
|
||||||
|
n_draft = std::min(30, n_draft + 8);
|
||||||
|
} else if (all_accepted) {
|
||||||
|
LOG(" - partially drafted tokens accepted - no change\n");
|
||||||
|
} else {
|
||||||
|
LOG(" - drafted token rejected - n_draft -= 1\n");
|
||||||
|
n_draft = std::max(2, n_draft - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
drafted.clear();
|
drafted.clear();
|
||||||
drafted.push_back(id);
|
drafted.push_back(id);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user