From a37696d4f1bdd66aa4bc9f8a1b86ccfff725c603 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Thu, 4 Apr 2024 18:25:19 -0400
Subject: [PATCH] speculative : more robust tokenizer comparison

Reject a draft model whose tokenizer setup differs from the target
model. The check now runs right after the draft model is loaded and
compares:

  - the value returned by llama_vocab_type()
  - llama_add_bos_token() and llama_add_eos_token()
  - llama_token_bos() and llama_token_eos()

This replaces the old add_bos-only check that sat next to prompt
tokenization.

The vocab type is kept in an int, not a bool: converting to bool maps
every non-zero value to true, so two different non-zero vocab types
would compare equal and the "%d" output would only ever show 0 or 1.
---
NOTE(review): this copy of the patch had its line breaks collapsed; the
layout below was rebuilt from the hunk line counts. The context line
`std::vector<llama_token> inp;` had lost its template argument and
`<llama_token>` is presumed -- confirm against the tree. The post-image
blob id on the `index` line predates the bool -> int change.

 examples/speculative/speculative.cpp | 34 ++++++++++++++++++----------
 1 file changed, 22 insertions(+), 12 deletions(-)

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index d4a7c3a01..6a7367b0c 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -76,6 +76,28 @@ int main(int argc, char ** argv) {
     params.n_threads_batch = params.n_threads_batch_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
+    const int vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG("vocab_type tgt: %d\n", vocab_type_tgt);
+
+    const int vocab_type_dft = llama_vocab_type(model_dft);
+    LOG("vocab_type dft: %d\n", vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__);
+        fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
+        return 1;
+    }
+
+    if (
+        llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
+        return 1;
+    }
+
     {
         const int n_vocab_tgt = llama_n_vocab(model_tgt);
         const int n_vocab_dft = llama_n_vocab(model_dft);
@@ -105,18 +127,6 @@ int main(int argc, char ** argv) {
 
 
     // Tokenize the prompt
-    const bool add_bos_tgt = llama_should_add_bos_token(model_tgt);
-    LOG("add_bos tgt: %d\n", add_bos_tgt);
-
-    const bool add_bos_dft = llama_should_add_bos_token(model_dft);
-    LOG("add_bos dft: %d\n", add_bos_dft);
-
-    if (add_bos_tgt != add_bos_dft) {
-        fprintf(stderr, "%s: error: draft model add_bos must match target model to use speculation but ", __func__);
-        fprintf(stderr, "add_bos_dft = %d while add_bos_tgt = %d\n", add_bos_dft, add_bos_tgt);
-        return 1;
-    }
-
     std::vector<llama_token> inp;
     inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true);
 