diff --git a/common/common.cpp b/common/common.cpp
index 73ff0e85b..657d2ffa8 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2417,14 +2417,21 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(
         }
     }
 
+    const int n_eos = llama_n_eos(llama_get_model(lctx));
+    std::vector<int32_t> eos_tokens(n_eos, 0);
+    int32_t* eos_ptr = eos_tokens.data();
+    llama_token_eos(llama_get_model(lctx), eos_ptr);
     if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+        for (int32_t i = 0; i < n_eos; ++i) {
+            params.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
+        }
     }
 
     if (params.warmup) {
         LOG("warming up the model with an empty run\n");
 
-        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
+        std::vector<llama_token> tmp = { llama_token_bos(model) };
+        tmp.insert(tmp.end(), eos_tokens.begin(), eos_tokens.end());
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
@@ -3357,8 +3364,17 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
 
-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
-    const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
+    const int n_eos = llama_n_eos(llama_get_model(lctx));
+    std::vector<int32_t> eos_tokens(n_eos, 0);
+    int32_t* eos_ptr = eos_tokens.data();
+    llama_token_eos(llama_get_model(lctx), eos_ptr);
+    bool ignore_eos = false;
+    for (auto eos: eos_tokens) {
+        const auto logit_bias_eos = sparams.logit_bias.find(eos);
+        if (logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY) {
+            ignore_eos = true;
+        }
+    }
     fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
"true" : "false"); yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); @@ -3371,7 +3387,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "logit_bias:\n"); for (std::pair lb : sparams.logit_bias) { - if (ignore_eos && lb.first == logit_bias_eos->first) { + if (ignore_eos && std::count(eos_tokens.begin(), eos_tokens.end(), lb.first)) { continue; } fprintf(stream, " %d: %f", lb.first, lb.second); diff --git a/common/train.cpp b/common/train.cpp index fef1e57c9..96ea4165e 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -240,7 +240,11 @@ int64_t get_example_targets_batch( ggml_set_f32(target_probs, 0.0f); llama_token bos = llama_token_bos(llama_get_model(lctx)); - llama_token eos = llama_token_eos(llama_get_model(lctx)); + const int n_eos = llama_n_eos(llama_get_model(lctx)); + std::vector eos_tokens(n_eos, 0); + int32_t* eos_ptr = eos_tokens.data(); + llama_token_eos(llama_get_model(lctx), eos_ptr); + llama_token eos = eos_ptr[0]; // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples); for (int k=0; k= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) { - inp.push_back(llama_token_eos(model)); + const int n_eos = llama_n_eos(model); + std::vector eos_tokens(n_eos, 0); + int32_t* eos_ptr = eos_tokens.data(); + llama_token_eos(model, eos_ptr); + + if (!eos_tokens.empty() && (inp.empty() || std::count(eos_tokens.begin(), eos_tokens.end(), inp.back()))) { + inp.insert(inp.end(), eos_tokens.begin(), eos_tokens.end()); } chunk.tokens = inp; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f9a86961f..5be18bc54 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1021,7 +1021,13 @@ struct server_context { slot.sparams.logit_bias.clear(); if (json_value(data, "ignore_eos", false)) { - slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + const int n_eos = llama_n_eos(model); + std::vector eos_tokens(n_eos, 0); + int32_t* eos_ptr = eos_tokens.data(); + llama_token_eos(model, eos_ptr); + for (int32_t i = 0; i < n_eos; ++i) { + slot.sparams.logit_bias[eos_ptr[i]] = -INFINITY; + } } const auto & logit_bias = data.find("logit_bias"); @@ -1308,9 +1314,17 @@ struct server_context { } json get_formated_generation(const server_slot & slot) const { - const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); - const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); - + const int n_eos = llama_n_eos(model); + std::vector eos_tokens(n_eos, 0); + int32_t* eos_ptr = eos_tokens.data(); + llama_token_eos(model, eos_ptr); + bool ignore_eos = false; + for (auto eos: eos_tokens) { + const auto logit_bias_eos = slot.sparams.logit_bias.find(eos); + if (logit_bias_eos != slot.sparams.logit_bias.end() && eos < 0.0f && std::isinf(logit_bias_eos->second)) { + ignore_eos = true; + } + } std::vector samplers_sequence; samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); for (const auto & sampler_type : slot.sparams.samplers_sequence) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 0939a1a6a..43b4278ae 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -88,12 +88,21 @@ int main(int argc, char ** argv) { fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt); return 1; } + const int 
+    std::vector<int32_t> eos_tokens_tgt(n_eos_tgt, 0);
+    int32_t* eos_ptr_tgt = eos_tokens_tgt.data();
+    llama_token_eos(model_tgt, eos_ptr_tgt);
+
+    const int n_eos_dft = llama_n_eos(model_dft);
+    std::vector<int32_t> eos_tokens_dft(n_eos_dft, 0);
+    int32_t* eos_ptr_dft = eos_tokens_dft.data();
+    llama_token_eos(model_dft, eos_ptr_dft);
 
     if (
         llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
         llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
         llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
-        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+        eos_tokens_tgt != eos_tokens_dft
     ) {
         fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
         return 1;
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 65aa3298d..cb512c4f2 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -88,7 +88,7 @@ class Keys:
         SCORES      = "tokenizer.ggml.scores"
         MERGES      = "tokenizer.ggml.merges"
         BOS_ID      = "tokenizer.ggml.bos_token_id"
-        EOS_ID      = "tokenizer.ggml.eos_token_id"
+        EOS_ID      = "tokenizer.ggml.eos_token_id"  # recommend eos_id_list instead
         UNK_ID      = "tokenizer.ggml.unknown_token_id"
         SEP_ID      = "tokenizer.ggml.seperator_token_id"
         PAD_ID      = "tokenizer.ggml.padding_token_id"
@@ -107,6 +107,8 @@ class Keys:
         SUFFIX_ID   = "tokenizer.ggml.suffix_token_id"
         MIDDLE_ID   = "tokenizer.ggml.middle_token_id"
         EOT_ID      = "tokenizer.ggml.eot_token_id"
+        EOS_ID_LIST = "tokenizer.ggml.eos_token_id_list"
+
 #
@@ -1091,7 +1093,7 @@ KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE
 KEY_TOKENIZER_SCORES      = Keys.Tokenizer.SCORES
 KEY_TOKENIZER_MERGES      = Keys.Tokenizer.MERGES
 KEY_TOKENIZER_BOS_ID      = Keys.Tokenizer.BOS_ID
-KEY_TOKENIZER_EOS_ID      = Keys.Tokenizer.EOS_ID
+KEY_TOKENIZER_EOS_ID_LIST = Keys.Tokenizer.EOS_ID_LIST
 KEY_TOKENIZER_UNK_ID      = Keys.Tokenizer.UNK_ID
 KEY_TOKENIZER_SEP_ID      = Keys.Tokenizer.SEP_ID
 KEY_TOKENIZER_PAD_ID      = Keys.Tokenizer.PAD_ID
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index a697f657b..e4f6868d9 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -510,9 +510,9 @@ class GGUFWriter:
 
     def add_bos_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.BOS_ID, id)
-
-    def add_eos_token_id(self, id: int) -> None:
-        self.add_uint32(Keys.Tokenizer.EOS_ID, id)
+
+    def add_eos_token_id_list(self, id: Sequence[int]) -> None:
+        self.add_array(Keys.Tokenizer.EOS_ID_LIST, id)
 
     def add_unk_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.UNK_ID, id)
diff --git a/llama.cpp b/llama.cpp
index a2ac68379..ea5c76cac 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -337,7 +337,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_SCORES,
     LLM_KV_TOKENIZER_MERGES,
     LLM_KV_TOKENIZER_BOS_ID,
-    LLM_KV_TOKENIZER_EOS_ID,
+    LLM_KV_TOKENIZER_EOS_ID, // compatibility with previous versions
     LLM_KV_TOKENIZER_UNK_ID,
     LLM_KV_TOKENIZER_SEP_ID,
     LLM_KV_TOKENIZER_PAD_ID,
@@ -352,6 +352,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_SUFFIX_ID,
     LLM_KV_TOKENIZER_MIDDLE_ID,
     LLM_KV_TOKENIZER_EOT_ID,
+    LLM_KV_TOKENIZER_EOS_ID_LIST
 };
 
 static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
@@ -438,6 +439,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_SUFFIX_ID,      "tokenizer.ggml.suffix_token_id"    },
     { LLM_KV_TOKENIZER_MIDDLE_ID,      "tokenizer.ggml.middle_token_id"    },
     { LLM_KV_TOKENIZER_EOT_ID,         "tokenizer.ggml.eot_token_id"       },
+    { LLM_KV_TOKENIZER_EOS_ID_LIST,    "tokenizer.ggml.eos_token_id_list"  },
 };
 
 struct LLM_KV {
@@ -2328,6 +2330,7 @@ struct llama_vocab {
     id special_pad_id  = -1;
     id special_cls_id  = -1;
     id special_mask_id = -1;
+    std::set<llama_token> special_eos_id_list;
 
     id linefeed_id       = 13;
     id special_prefix_id = -1;
@@ -5084,6 +5087,24 @@ static void llm_load_vocab(
         }
     }
 
+    const int eos_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_EOS_ID_LIST).c_str());
+    if (eos_idx == -1) {
+        vocab.special_eos_id_list.clear();
+        vocab.special_eos_id_list.insert(vocab.special_eos_id);
+    } else {
+        const uint32_t n_eos = gguf_get_arr_n(ctx, eos_idx);
+        const int* eos_tokens = (const int*)gguf_get_arr_data(ctx, eos_idx);
+        if (n_eos > 0) {
+            vocab.special_eos_id_list.clear();
+        } else {
+            vocab.special_eos_id_list.clear();
+            vocab.special_eos_id_list.insert(vocab.special_eos_id);
+        }
+        for (uint32_t i = 0; i < n_eos; ++i) {
+            vocab.special_eos_id_list.insert(eos_tokens[i]);
+        }
+    }
+
     // Handle add_bos_token and add_eos_token
     {
         bool temp = true;
@@ -5273,7 +5294,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
 
     // special tokens
     if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
-    if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
+    if (!vocab.special_eos_id_list.empty()) {
+        for (auto it = vocab.special_eos_id_list.begin(); it != vocab.special_eos_id_list.end(); ++it) {
+            LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, *it, vocab.id_to_token[*it].text.c_str() );
+        }
+    }
     if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
     if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
     if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
@@ -13482,8 +13507,8 @@ struct llm_tokenizer_bpe {
 
     bool append_eos(std::vector<llama_vocab::id> & output) const {
         if (vocab.tokenizer_add_eos) {
-            GGML_ASSERT(vocab.special_eos_id != -1);
-            output.push_back(vocab.special_eos_id);
+            GGML_ASSERT(!vocab.special_eos_id_list.empty());
+            output.insert(output.end(), vocab.special_eos_id_list.begin(), vocab.special_eos_id_list.end());
             return true;
         }
         return false;
@@ -13496,7 +13521,7 @@ struct llm_tokenizer_bpe {
                 "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                 "Are you sure this is what you want?\n", __FUNCTION__);
         }
-        if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
+        if (vocab.tokenizer_add_eos && output.size() >= 2 && vocab.special_eos_id_list.find(*(output.end()-2)) != vocab.special_eos_id_list.end()) {
             LLAMA_LOG_WARN(
                 "%s: Added a EOS token to the prompt as specified by the model but the prompt "
                 "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
" @@ -13966,8 +13991,8 @@ static std::vector llama_tokenize_internal(const llama_vocab & } if (add_special && vocab.tokenizer_add_eos) { - GGML_ASSERT(vocab.special_eos_id != -1); - output.push_back(vocab.special_eos_id); + GGML_ASSERT(!vocab.special_eos_id_list.empty()); + output.insert(output.end(), vocab.special_eos_id_list.begin(), vocab.special_eos_id_list.end()); } // add suffix to chatglm3 if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { @@ -16966,6 +16991,10 @@ int32_t llama_n_vocab(const struct llama_model * model) { return model->hparams.n_vocab; } +int32_t llama_n_eos(const struct llama_model * model) { + return model->vocab.special_eos_id_list.size(); +} + int32_t llama_n_ctx_train(const struct llama_model * model) { return model->hparams.n_ctx_train; } @@ -18550,21 +18579,8 @@ llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_to } bool llama_token_is_eog(const struct llama_model * model, llama_token token) { - auto arch_name = llama_model_arch_name(model->arch); - auto vocab_type = model->vocab.type; - if (strcmp(arch_name, "chatglm") == 0) { - if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4 - return token != -1 && ( - token == llama_token_eos(model) || - token == llama_token_eot(model) || - token == 151329 || - token == 151336 || - token == 151338 - ); - } - } return token != -1 && ( - token == llama_token_eos(model) || + model->vocab.special_eos_id_list.count(token) || token == llama_token_eot(model) ); } @@ -18577,8 +18593,11 @@ llama_token llama_token_bos(const struct llama_model * model) { return model->vocab.special_bos_id; } -llama_token llama_token_eos(const struct llama_model * model) { - return model->vocab.special_eos_id; +void llama_token_eos(const struct llama_model * model, llama_token* token_list) { + int ind = 0; + for (auto it = model->vocab.special_eos_id_list.begin(); it != model->vocab.special_eos_id_list.end(); ++it) { + token_list[ind++] = *it; + } } llama_token llama_token_cls(const struct llama_model * model) { @@ -18952,10 +18971,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl == "chatglm3" || - (tmpl.find("add_generation_prompt") != std::string::npos && - tmpl.find("for message in messages") != std::string::npos && - tmpl.find("loop.first") != std::string::npos)) { + } else if (tmpl == "chatglm3" || tmpl.find("[gMASK]sop") != std::string::npos) { // chatglm3-6b ss << "[gMASK]" << "sop"; for (auto message : chat) { @@ -18965,7 +18981,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "ChatGLM4") { + } else if (tmpl == "chatglm4" || tmpl.find("[gMASK]") != std::string::npos) { ss << "[gMASK]" << ""; for (auto message : chat) { std::string role(message->role); diff --git a/llama.h b/llama.h index a85b568b9..84172f443 100644 --- a/llama.h +++ b/llama.h @@ -448,6 +448,7 @@ extern "C" { LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); + LLAMA_API int32_t llama_n_eos (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_n_layer (const struct llama_model * model); @@ -851,7 +852,7 @@ extern "C" { // Special tokens LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence - 
-    LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
+    LLAMA_API void llama_token_eos(const struct llama_model * model, llama_token* token_list); // end-of-sentence
     LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
     LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 0fe4d2967..399dc57b9 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -60,7 +60,7 @@ int main(void) {
         // ChatGLM3
         "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
         // ChatGLM4
-        "ChatGLM4",
+        "chatglm4",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
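
For reference, a minimal caller-side sketch of the API change above: llama_n_eos() reports how many EOS ids the vocabulary carries, and llama_token_eos() now fills a caller-provided buffer instead of returning a single token. This sketch is not part of the patch; the helper name ban_eos_tokens and the map parameter are illustrative only.

#include <cmath>
#include <unordered_map>
#include <vector>

#include "llama.h"

// Collect every EOS id advertised by the model and bias it to -infinity,
// mirroring what the ignore_eos paths in common.cpp and server.cpp do.
static void ban_eos_tokens(const struct llama_model * model,
                           std::unordered_map<llama_token, float> & logit_bias) {
    const int n_eos = llama_n_eos(model);          // number of EOS ids in the vocab
    std::vector<llama_token> eos_tokens(n_eos, 0);
    llama_token_eos(model, eos_tokens.data());     // fills the caller's buffer

    for (const llama_token eos : eos_tokens) {
        logit_bias[eos] = -INFINITY;
    }
}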