Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template

ochafik 2025-01-13 19:58:15 +00:00
parent cb72cf1fc3
commit 78861a3eb2
7 changed files with 17 additions and 23 deletions

View File

@@ -1822,17 +1822,6 @@ std::string common_chat_format_example(const struct llama_model * model,
return common_chat_apply_template(model, tmpl, msgs, true);
}
-static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
-int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
-if (tlen > 0) {
-std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-return std::string(curr_tmpl_buf.data(), tlen);
-}
-}
-return "";
-}
llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
{
auto vocab = llama_model_get_vocab(model);
@@ -1841,9 +1830,8 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
std::string default_template_src = chat_template_override;
std::string tool_use_template_src = chat_template_override;
if (chat_template_override.empty()) {
-// TODO:
-default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template");
-tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+default_template_src = llama_model_chat_template(model, /* name */ nullptr);
+tool_use_template_src = llama_model_chat_template(model, /* name */ "tool_use");
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!tool_use_template_src.empty()) {
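Note that the new accessor returns nullptr rather than an empty string when the key is absent (per the llama.h comment further down in this commit), so a caller that wants the removed helper's empty-string semantics can guard explicitly. A minimal sketch, not the commit's code; the intermediate variable names are illustrative:

// nullptr asks for the default template; a non-null name asks for a named variant.
// Both calls return nullptr when the model carries no such metadata.
const char * def_src  = llama_model_chat_template(model, /* name */ nullptr);
const char * tool_src = llama_model_chat_template(model, /* name */ "tool_use");

std::string default_template_src  = def_src  ? def_src  : "";
std::string tool_use_template_src = tool_src ? tool_src : "";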

View File

@@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(LlamaData & llama_data, const bool append) {
int result = llama_chat_apply_template(
-llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
-result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
}
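The surrounding code is the usual grow-and-retry use of llama_chat_apply_template: the first call reports how many bytes the formatted prompt needs, the buffer is resized, and the call is repeated. A standalone sketch of that pattern under the same API; the helper name format_chat and the 1024-byte initial guess are mine, not part of this commit (requires <string>, <vector> and llama.h):

// Format a chat with the model's default template, growing the buffer on demand.
static std::string format_chat(const llama_model * model, const std::vector<llama_chat_message> & msgs) {
    const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
    if (tmpl == nullptr) {
        return "";                                    // model stores no chat template
    }
    std::vector<char> buf(1024);                      // initial guess, may be too small
    int32_t n = llama_chat_apply_template(tmpl, msgs.data(), msgs.size(),
                                          /* add_ass */ true, buf.data(), (int32_t) buf.size());
    if (n > (int32_t) buf.size()) {
        buf.resize(n);                                // grow to the reported size and retry
        n = llama_chat_apply_template(tmpl, msgs.data(), msgs.size(),
                                      /* add_ass */ true, buf.data(), (int32_t) buf.size());
    }
    return n > 0 ? std::string(buf.data(), n) : std::string();
}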

View File

@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
break;
}
-const char * tmpl = llama_model_chat_template(model);
+const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
// add the user input to the message list and format it
messages.push_back({"user", strdup(user.c_str())});

View File

@@ -503,7 +503,7 @@ extern "C" {
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
// Get the default chat template. Returns nullptr if not available
-LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
// Returns the total number of parameters in the model
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);

View File

@@ -179,6 +179,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
+{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
@@ -1443,10 +1444,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
};
-LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
+LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
std::string LLM_KV::operator()(llm_kv kv) const {
-return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
+: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
}
std::string LLM_TN_IMPL::str() const {

View File

@@ -177,6 +177,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV,
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
+LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
LLM_KV_TOKENIZER_FIM_PRE_ID,
LLM_KV_TOKENIZER_FIM_SUF_ID,
LLM_KV_TOKENIZER_FIM_MID_ID,
@@ -335,9 +336,10 @@ enum llm_tensor_layer {
};
struct LLM_KV {
-LLM_KV(llm_arch arch);
+LLM_KV(llm_arch arch, const char * suffix = nullptr);
llm_arch arch;
+const char * suffix;
std::string operator()(llm_kv kv) const;
};
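Because the new parameter is defaulted, existing LLM_KV(arch) call sites compile unchanged; both forms, purely illustrative:

LLM_KV kv_plain(LLM_ARCH_LLAMA);              // suffix == nullptr, previous behaviour
LLM_KV kv_named(LLM_ARCH_LLAMA, "tool_use");  // carries a suffix; used by the *_N lookup in llama-model.cpp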

View File

@@ -3912,8 +3912,10 @@ uint64_t llama_model_size(const struct llama_model * model) {
return model->size();
}
-const char * llama_model_chat_template(const struct llama_model * model) {
-const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
+const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
+const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+: LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
+const auto & it = model->gguf_kv.find(key);
if (it == model->gguf_kv.end()) {
return nullptr;
}
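For comparison, the generic metadata accessor that the removed common.cpp helper wrapped can still read the same key directly. A minimal sketch using the literal key from that helper's call site; model is assumed to be a loaded llama_model handle:

// Two-call pattern: first query the value length, then fill a buffer.
// llama_model_meta_val_str returns the string length, or a negative value on failure.
int32_t len = llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use", nullptr, 0);
std::string tool_use_tmpl;
if (len > 0) {
    std::vector<char> buf(len + 1, 0);
    if (llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use", buf.data(), buf.size()) == len) {
        tool_use_tmpl.assign(buf.data(), len);
    }
}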