diff --git a/common/chat-template.hpp b/common/chat-template.hpp
index 58e119a3b..0e88fb361 100644
--- a/common/chat-template.hpp
+++ b/common/chat-template.hpp
@@ -33,6 +33,29 @@ struct chat_template_caps {
     bool requires_typed_content = false;
 };
 
+struct chat_template_inputs {
+    nlohmann::ordered_json messages;
+    nlohmann::ordered_json tools;
+    bool add_generation_prompt = true;
+    nlohmann::ordered_json extra_context;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
+};
+
+struct chat_template_options {
+    bool apply_polyfills = true;
+    bool use_bos_token = true;
+    bool use_eos_token = true;
+    bool define_strftime_now = true;
+
+    bool polyfill_tools = true;
+    bool polyfill_tool_call_examples = true;
+    bool polyfill_tool_calls = true;
+    bool polyfill_tool_responses = true;
+    bool polyfill_system_role = true;
+    bool polyfill_object_arguments = true;
+    bool polyfill_typed_content = true;
+};
+
 class chat_template {
 
   private:
@@ -41,6 +64,7 @@ class chat_template {
     std::string bos_token_;
     std::string eos_token_;
     std::shared_ptr<minja::TemplateNode> template_root_;
+    std::string tool_call_example_;
 
     std::string try_raw_render(
         const nlohmann::ordered_json & messages,
@@ -49,7 +73,18 @@ class chat_template {
         const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
     {
         try {
-            auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false);
+            chat_template_inputs inputs;
+            inputs.messages = messages;
+            inputs.tools = tools;
+            inputs.add_generation_prompt = add_generation_prompt;
+            inputs.extra_context = extra_context;
+            // Use fixed date for tests
+            inputs.now = std::chrono::system_clock::from_time_t(0);
+
+            chat_template_options opts;
+            opts.apply_polyfills = false;
+
+            auto prompt = apply(inputs, opts);
             // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
             return prompt;
         } catch (const std::exception & e) {
@@ -176,6 +211,58 @@ class chat_template {
             caps_.supports_tool_responses = contains(out, "Some response!");
             caps_.supports_tool_call_id = contains(out, "call_911_");
         }
+
+        try {
+            if (!caps_.supports_tools) {
+                const json user_msg {
+                    {"role", "user"},
+                    {"content", "Hey"},
+                };
+                const json args {
+                    {"arg1", "some_value"},
+                };
+                const json tool_call_msg {
+                    {"role", "assistant"},
+                    {"content", nullptr},
+                    {"tool_calls", json::array({
+                        {
+                            // TODO: detect if requires numerical id or fixed length == 6 like Nemo
+                            {"id", "call_1___"},
+                            {"type", "function"},
+                            {"function", {
+                                {"name", "tool_name"},
+                                {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
+                            }},
+                        },
+                    })},
+                };
+                std::string prefix, full;
+                {
+                    chat_template_inputs inputs;
+                    inputs.messages = json::array({user_msg});
+                    inputs.add_generation_prompt = true;
+                    prefix = apply(inputs);
+                }
+                {
+                    chat_template_inputs inputs;
+                    inputs.messages = json::array({user_msg, tool_call_msg});
+                    inputs.add_generation_prompt = false;
+                    full = apply(inputs);
+                }
+
+                if (full.find(prefix) != 0) {
+                    if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
+                        prefix = prefix.substr(0, prefix.size() - eos_token_.size());
+                    }
+                }
+                if (full.find(prefix) != 0) {
+                    fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
+                }
+                tool_call_example_ = full.substr(prefix.size());
+            }
+        } catch (const std::exception & e) {
+            fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
+        }
     }
 
     const std::string & source() const { return source_; }
@@ -183,28 +270,72 @@ class chat_template {
     const std::string & eos_token() const { return eos_token_; }
     const chat_template_caps & original_caps() const { return caps_; }
 
+    // Deprecated, please use the form with chat_template_inputs and chat_template_options
     std::string apply(
         const nlohmann::ordered_json & messages,
         const nlohmann::ordered_json & tools,
         bool add_generation_prompt,
         const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
-        bool adjust_inputs = true) const
+        bool apply_polyfills = true)
+    {
+        fprintf(stderr, "[%s] Deprecated!\n", __func__);
+        chat_template_inputs inputs;
+        inputs.messages = messages;
+        inputs.tools = tools;
+        inputs.add_generation_prompt = add_generation_prompt;
+        inputs.extra_context = extra_context;
+        inputs.now = std::chrono::system_clock::now();
+
+        chat_template_options opts;
+        opts.apply_polyfills = apply_polyfills;
+
+        return apply(inputs, opts);
+    }
+
+    std::string apply(
+        const chat_template_inputs & inputs,
+        const chat_template_options & opts = chat_template_options()) const
     {
         json actual_messages;
 
-        auto needs_adjustments = adjust_inputs && (false
-            || !caps_.supports_system_role
-            || !caps_.supports_tools
-            || !caps_.supports_tool_responses
-            || !caps_.supports_tool_calls
-            || caps_.requires_object_arguments
-            || caps_.requires_typed_content
+        auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
+        auto has_tool_calls = false;
+        auto has_tool_responses = false;
+        auto has_string_content = false;
+        for (const auto & message : inputs.messages) {
+            if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
+                has_tool_calls = true;
+            }
+            if (message.contains("role") && message["role"] == "tool") {
+                has_tool_responses = true;
+            }
+            if (message.contains("content") && message["content"].is_string()) {
+                has_string_content = true;
+            }
+        }
+
+        auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
+        auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
+        auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
+        auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
+        auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
+        auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
+        auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
+
+        auto needs_polyfills = opts.apply_polyfills && (false
+            || polyfill_system_role
+            || polyfill_tools
+            || polyfill_tool_calls
+            || polyfill_tool_responses
+            || polyfill_object_arguments
+            || polyfill_typed_content
         );
-        if (needs_adjustments) {
+
+        if (needs_polyfills) {
             actual_messages = json::array();
 
             auto add_message = [&](const json & msg) {
-                if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
+                if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
                     actual_messages.push_back({
                         {"role", msg.at("role")},
                         {"content", {{
@@ -227,9 +358,17 @@ class chat_template {
                     pending_system.clear();
                 }
             };
 
-            auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
-            for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
+            json adjusted_messages;
+            if (polyfill_tools) {
+                adjusted_messages = add_system(inputs.messages,
+                    "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
+                    (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
+            } else {
+                adjusted_messages = inputs.messages;
+            }
+
+            for (const auto & message_ : adjusted_messages) {
                 auto message = message_;
                 if (!message.contains("role") || !message.contains("content")) {
                     throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
@@ -237,7 +376,7 @@ class chat_template {
                 std::string role = message.at("role");
 
                 if (message.contains("tool_calls")) {
-                    if (caps_.requires_object_arguments || !caps_.supports_tool_calls) {
+                    if (polyfill_object_arguments || polyfill_tool_calls) {
                         for (auto & tool_call : message.at("tool_calls")) {
                             if (tool_call["type"] == "function") {
                                 auto & function = tool_call.at("function");
@@ -252,7 +391,7 @@ class chat_template {
                             }
                         }
                     }
-                    if (!caps_.supports_tool_calls) {
+                    if (polyfill_tool_calls) {
                         auto content = message.at("content");
                         auto tool_calls = json::array();
                         for (const auto & tool_call : message.at("tool_calls")) {
@@ -279,7 +418,7 @@ class chat_template {
                         message.erase("tool_calls");
                     }
                 }
-                if (!caps_.supports_tool_responses && role == "tool") {
+                if (polyfill_tool_responses && role == "tool") {
                     message["role"] = "user";
                     auto obj = json {
                         {"tool_response", {
@@ -296,7 +435,7 @@ class chat_template {
                     message.erase("name");
                 }
 
-                if (!message["content"].is_null() && !caps_.supports_system_role) {
+                if (!message["content"].is_null() && polyfill_system_role) {
                     std::string content = message.at("content");
                     if (role == "system") {
                         if (!pending_system.empty()) pending_system += "\n";
@@ -315,28 +454,36 @@ class chat_template {
                 }
                 add_message(message);
             }
-            if (!caps_.supports_system_role) {
-                flush_sys();
-            }
+            flush_sys();
         } else {
-            actual_messages = messages;
+            actual_messages = inputs.messages;
        }
 
         auto context = minja::Context::make(json({
             {"messages", actual_messages},
-            {"add_generation_prompt", add_generation_prompt},
-            {"bos_token", bos_token_},
-            {"eos_token", eos_token_},
+            {"add_generation_prompt", inputs.add_generation_prompt},
         }));
+        context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
+        context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
eos_token_ : ""); + if (opts.define_strftime_now) { + auto now = inputs.now; + context->set("strftime_now", Value::callable([now](const std::shared_ptr &, minja::ArgumentsValue & args) { + args.expectArgs("strftime_now", {1, 1}, {0, 0}); + auto format = args.args[0].get(); - if (!tools.is_null()) { - auto tools_val = minja::Value(tools); - context->set("tools", tools_val); + auto time = std::chrono::system_clock::to_time_t(now); + auto local_time = *std::localtime(&time); + std::ostringstream ss; + ss << std::put_time(&local_time, format.c_str()); + return ss.str(); + })); } - if (!extra_context.is_null()) { - for (auto & kv : extra_context.items()) { - minja::Value val(kv.value()); - context->set(kv.key(), val); + if (!inputs.tools.is_null()) { + context->set("tools", minja::Value(inputs.tools)); + } + if (!inputs.extra_context.is_null()) { + for (auto & kv : inputs.extra_context.items()) { + context->set(kv.key(), minja::Value(kv.value())); } } @@ -353,7 +500,7 @@ class chat_template { std::string existing_system = messages_with_system.at(0).at("content"); messages_with_system[0] = json { {"role", "system"}, - {"content", existing_system + "\n" + system_prompt}, + {"content", existing_system + "\n\n" + system_prompt}, }; } else { messages_with_system.insert(messages_with_system.begin(), json { diff --git a/common/chat.cpp b/common/chat.cpp index 4a113c0ca..ef1c6fb3d 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -163,6 +163,28 @@ static void foreach_function(const json & tools, const std::function", "<|END_ACTION|>", }; - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; return data; } @@ -477,7 +499,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com builder.add_rule("root", string_join(tool_rules, " | ")); }, grammar_options); data.additional_stops.push_back("<|eom_id|>"); - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { {"tools_in_user_message", false}, {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, }); @@ -542,7 +564,8 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ }; builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space"); }, grammar_options); - data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? 
+    data.prompt = prompt;
     data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
     return data;
 }
@@ -556,10 +579,10 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input)
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     fprintf(stderr, "%s\n", __func__);
     common_chat_params data;
-    data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
         {"datetime", "Jan 29 2025 13:00:00 GMT"},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
-    }, /* adjust_inputs= */ false);
+    });
     if (!inputs.tools.is_null() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != "required";
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -603,7 +626,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_params data;
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (!inputs.tools.is_null() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != "required";
@@ -730,7 +753,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
         data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
         data.preserved_tokens = { "</tool_call>" };
     }, grammar_options);
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
     return data;
 }
@@ -846,7 +869,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
     if (!inputs.json_schema.is_null()) {
diff --git a/common/common.cpp b/common/common.cpp
index edba6fb4b..8661e164a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1904,10 +1904,6 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
             default_template_src = CHATML_TEMPLATE_SRC;
         }
     }
-    std::string token_bos;
-    std::string token_eos;
-    // TODO: update logic that adds BOS and EOS tokens to the tokenized prompt, in favour of the template.
-#if 0
     auto vocab = llama_model_get_vocab(model);
     const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
         if (token == LLAMA_TOKEN_NULL) {
@@ -1920,9 +1916,8 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
             return common_token_to_piece(vocab, token, true);
         }
     };
-    token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
-    token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
-#endif
+    auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
+    auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
     try {
         return {
             has_explicit_template,
diff --git a/common/minja.hpp b/common/minja.hpp
index e77eb69d5..c304b5c66 100644
--- a/common/minja.hpp
+++ b/common/minja.hpp
@@ -2194,7 +2194,7 @@ private:
     }
 
     TemplateTokenVector tokenize() {
-        static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})");
+        static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})");
         static std::regex expr_open_regex(R"(\{\{([-~])?)");
         static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
         static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
@@ -2615,6 +2615,7 @@ inline std::shared_ptr<Context> Context::builtins() {
     }));
     globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr<Context> &, Value & args) {
         auto do_join = [](Value & items, const std::string & sep) {
+            if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
             std::ostringstream oss;
             auto first = true;
             for (size_t i = 0, n = items.size(); i < n; ++i) {
@@ -2695,6 +2696,10 @@ inline std::shared_ptr<Context> Context::builtins() {
         return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
             args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
"select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); + auto filter_fn = context->get(args.args[1]); if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); @@ -2772,6 +2777,7 @@ inline std::shared_ptr Context::builtins() { auto & items = args.args[0]; if (items.is_null()) return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); auto attr_name = args.args[1].get(); bool has_test = false; diff --git a/examples/run/run.cpp b/examples/run/run.cpp index ca9273155..39353ba30 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -848,7 +848,15 @@ static int apply_chat_template(const common_chat_template & tmpl, LlamaData & ll }); } try { - auto result = tmpl.apply(messages, /* tools= */ json(), append); + minja::chat_template_inputs tmpl_inputs; + tmpl_inputs.messages = messages; + tmpl_inputs.add_generation_prompt = append; + + minja::chat_template_options tmpl_opts; + tmpl_opts.use_bos_token = false; + tmpl_opts.use_eos_token = false; + + auto result = tmpl.apply(tmpl_inputs, tmpl_opts); llama_data.fmtted.resize(result.size() + 1); memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); return result.size(); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 50bd40738..b78da2cdb 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -18,12 +18,8 @@ using json = nlohmann::ordered_json; static common_chat_msg msg_from_json(const json & message) { - common_chat_msg ret{ - "assistant", - "", - {}, - /* .tool_plan = */ "", - }; + common_chat_msg ret; + ret.role = "assistant"; if (message.contains("content") && !message.at("content").is_null()) { ret.content = message.at("content"); }