diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 78681dd25..9fab8266d 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -114,35 +114,21 @@ int main(int argc, char ** argv) { struct llama_grammar * grammar_dft = NULL; struct llama_grammar * grammar_tgt = NULL; - grammar_parser::parse_state parsed_grammar_dft; - grammar_parser::parse_state parsed_grammar_tgt; + grammar_parser::parse_state parsed_grammar; std::vector grammar_mem(n_draft, NULL); + // if requested - load the grammar, error checking is omitted for brevity if (!params.grammar.empty()) { - // dft - { - parsed_grammar_dft = grammar_parser::parse(params.grammar.c_str()); - // will be empty (default) if there are parse errors - if (parsed_grammar_dft.rules.empty()) { - return 1; - } - - std::vector grammar_rules(parsed_grammar_dft.c_rules()); - grammar_dft = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar_dft.symbol_ids.at("root")); + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + return 1; } - // tgt - { - parsed_grammar_tgt = grammar_parser::parse(params.grammar.c_str()); - // will be empty (default) if there are parse errors - if (parsed_grammar_tgt.rules.empty()) { - return 1; - } - - std::vector grammar_rules(parsed_grammar_tgt.c_rules()); - grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar_tgt.symbol_ids.at("root")); - } + std::vector grammar_rules(parsed_grammar.c_rules()); + grammar_dft = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } const auto t_dec_start = ggml_time_us(); @@ -150,11 +136,12 @@ int main(int argc, char ** argv) { while (true) { LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted)); - // sample from the drafted tokens if any int i_dft = 0; while (true) { + // sample from the target model const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); + // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); @@ -170,8 +157,9 @@ int main(int argc, char ** argv) { ++n_predict; + // check if the draft matches the target if (i_dft < (int) drafted.size() && id == drafted[i_dft]) { - LOG("drafted token %d accepted\n", id); + LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str()); ++n_accept; ++n_past_tgt; ++n_past_dft; @@ -180,25 +168,20 @@ int main(int argc, char ** argv) { continue; } + // the drafted token was rejected or we are out of drafted tokens + if (i_dft < (int) drafted.size()) { - LOG("drafted token %d rejected\n", id); + LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n", + i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str()); if (grammar_mem[i_dft]) { grammar_dft = llama_grammar_copy(grammar_mem[i_dft]); - LOG("restored grammar %d\n", i_dft); + LOG("restored draft grammar state %d\n", i_dft); } + } else { + LOG("out of drafted tokens\n"); } - for (auto & g : grammar_mem) { - if (g) { - llama_grammar_free(g); - g = NULL; - } - } - - LOG("i_dft = %d, drafted.size() = %d\n", i_dft, (int) drafted.size()); - - // the drafted token was rejected or we are out of drafted tokens llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); ++n_past_dft; @@ -212,11 +195,20 @@ int main(int argc, char ** argv) { break; } + for (int i = 0; i < (int) grammar_mem.size(); ++i) { + auto & g = grammar_mem[i]; + if (g) { + LOG("freeing grammar state %d\n", i); + llama_grammar_free(g); + g = NULL; + } + } + if (n_predict > params.n_predict || has_eos) { break; } - // sample n_draft tokens from the draft model picking the best token + // sample n_draft tokens from the draft model using greedy decoding int n_past_cur = n_past_dft; for (int i = 0; i < n_draft; ++i) { // remember the grammar state @@ -244,11 +236,13 @@ int main(int argc, char ** argv) { LOG(" - draft candidate %3d: %6d (%8.3f) '%s'\n", i, cur_p.data[i].id, cur_p.data[i].p, llama_token_to_piece(ctx_dft, cur_p.data[i].id).c_str()); } - // too low probability, stop drafting + // TODO: better logic? if (cur_p.data[0].p < 2*cur_p.data[1].p) { + LOG("stopping drafting, probability too low: %8.f < 2*%8.f\n", cur_p.data[0].p, cur_p.data[1].p); break; } + // drafted token const llama_token id = cur_p.data[0].id; if (grammar_dft != NULL) { @@ -258,17 +252,21 @@ int main(int argc, char ** argv) { drafted.push_back(id); ++n_drafted; - if (i < n_draft - 1) { - // evaluate the drafted token on the draft model - llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); - ++n_past_cur; + // no need to evaluate the last drafted token, since we won't use the result + if (i == n_draft - 1) { + break; } + + // evaluate the drafted token on the draft model + llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); + ++n_past_cur; } // evaluate the target model on the drafted tokens llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads); ++n_past_tgt; + // the first token is always proposed by the traget model before the speculation loop drafted.erase(drafted.begin()); }