diff --git a/.editorconfig b/.editorconfig index eac38a15f..5d63d0a51 100644 --- a/.editorconfig +++ b/.editorconfig @@ -40,3 +40,11 @@ indent_style = tab [examples/cvector-generator/*.txt] trim_trailing_whitespace = unset insert_final_newline = unset + +[models/templates/*.jinja] +indent_style = unset +indent_size = unset +end_of_line = unset +charset = unset +trim_trailing_whitespace = unset +insert_final_newline = unset diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index ed1c357a5..0cbc3d640 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -205,7 +205,7 @@ jobs: run: | cd examples/server/tests $env:PYTHONIOENCODING = ":replace" - pytest -v -x + pytest -v -x -m "not slow" - name: Slow tests id: server_integration_tests_slow diff --git a/Makefile b/Makefile index 295522ba3..ef152d246 100644 --- a/Makefile +++ b/Makefile @@ -52,6 +52,7 @@ TEST_TARGETS = \ tests/test-arg-parser \ tests/test-autorelease \ tests/test-backend-ops \ + tests/test-chat \ tests/test-chat-template \ tests/test-double-float \ tests/test-grammar-integration \ @@ -983,6 +984,7 @@ OBJ_COMMON = \ $(DIR_COMMON)/ngram-cache.o \ $(DIR_COMMON)/sampling.o \ $(DIR_COMMON)/speculative.o \ + $(DIR_COMMON)/chat.o \ $(DIR_COMMON)/build-info.o \ $(DIR_COMMON)/json-schema-to-grammar.o @@ -1361,6 +1363,8 @@ llama-server: \ examples/server/httplib.h \ examples/server/index.html.hpp \ examples/server/loading.html.hpp \ + common/chat.cpp \ + common/chat.hpp \ common/chat-template.hpp \ common/json.hpp \ common/minja.hpp \ @@ -1471,6 +1475,11 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \ $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +tests/test-chat: tests/test-chat.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + tests/test-opt: tests/test-opt.cpp \ $(OBJ_GGML) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) diff --git a/README.md b/README.md index 382c67041..d40309875 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) - **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggerganov/llama.cpp/pull/11427 - **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode +- Universal tool call support in `llama-server`: https://github.com/ggerganov/llama.cpp/pull/9639 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim - Introducing GGUF-my-LoRA https://github.com/ggerganov/llama.cpp/discussions/10123 - Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669 diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 24b7f8741..72f0915c1 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -56,6 +56,8 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp + chat.cpp + chat.hpp chat-template.hpp common.cpp common.h diff --git a/common/chat.cpp b/common/chat.cpp new file mode 100644 index 000000000..d9a654892 --- /dev/null +++ b/common/chat.cpp @@ -0,0 +1,848 @@ +#include "chat.hpp" +#include "chat-template.hpp" +#include "json-schema-to-grammar.h" +#include "log.h" +#include "minja.hpp" + +std::string common_chat_format_name(common_chat_format format) { + switch (format) { + case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only"; + case COMMON_CHAT_FORMAT_GENERIC: return "Generic"; + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo"; + case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; + case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; + default: + throw std::runtime_error("Unknown chat format"); + } +} + +const common_grammar_options grammar_options { + /* .dotall = */ false, + /* .compact_spaces = */ false, + // /* .compact_spaces = */ true, +}; + +static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { + // // https://json.nlohmann.me/features/parsing/sax_interface/ + struct json_error_locator : public nlohmann::json_sax { + std::size_t position; + bool found_error; + + json_error_locator() : position(0), found_error(false) {} + + bool parse_error(std::size_t position, const std::string &, const json::exception &) override { + this->position = position - 1; + this->found_error = true; + return false; + } + bool null() override { return true; } + bool boolean(bool) override { return true; } + bool number_integer(number_integer_t) override { return true; } + bool number_unsigned(number_unsigned_t) override { return true; } + bool number_float(number_float_t, const string_t &) override { return true; } + bool string(string_t &) override { return true; } + bool binary(binary_t &) override { return true; } + bool start_object(std::size_t) override { return true; } + bool key(string_t &) override { return true; } + bool end_object() override { return true; } + bool start_array(std::size_t) override { return true; } + bool end_array() override { return true; } + }; + json_error_locator err_loc; + json::sax_parse(it, end, &err_loc); + + std::string::const_iterator temptative_end; + if (err_loc.found_error) { + temptative_end = it + err_loc.position; + } else { + temptative_end = end; + } + std::string json_sub {it, temptative_end}; + try { + out = json::parse(json_sub); + it = temptative_end; + return true; + } catch (const std::exception &) { + return false; + } +} + + +/** + * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. + * Aggregates the prefix, suffix and in-between text into the content. + */ +static common_chat_msg parse_json_tool_calls( + const std::string& input, + const std::optional & trigger_opt, + const std::regex & function_regex, + const std::regex & close_regex) { + std::smatch match; + + common_chat_msg result; + result.role = "assistant"; + + + auto end = input.end(); + auto it = input.begin(); + + if (trigger_opt) { + if (!std::regex_search(it, end, match, *trigger_opt)) { + result.content = input; + return result; + } + result.content = match.prefix().str(); + it = match.suffix().first; + } + + while (it != end) { + std::sregex_iterator rend; + std::sregex_iterator rit(it, end, function_regex); + if (rit == rend) { + fprintf(stderr, "No more tool calls found\n"); + result.content += std::string(it, end); + break; + } + auto name = rit->str(1); + result.content += std::string(it, rit->prefix().second); + it = rit->suffix().first; + + json arguments; + if (!parse_json(it, end, arguments)) { + throw std::runtime_error("Failed to parse json tool call arguments"); + } + if (!std::regex_search(it, end, match, close_regex)) { + throw std::runtime_error("Malformed input, missing closing pattern"); + } + it = match.suffix().first; + result.tool_calls.push_back({name, arguments.is_string() ? arguments.get() : arguments.dump(), /* id= */ ""}); + } + return result; +} + +static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { + auto content_end = input.find(prefix); + size_t tc_start = std::string::npos; + + common_chat_msg result; + result.role = "assistant"; + const auto process_tool_calls = [&](const json & tool_calls) { + for (const auto & tool_call : tool_calls) { + const auto & arguments = tool_call["arguments"]; + result.tool_calls.push_back({ + tool_call["name"], + arguments.is_string() ? arguments.get() : arguments.dump(), + tool_call.contains("id") ? tool_call["id"] : "", + }); + } + }; + if (content_end == std::string::npos) { + result.content = input; + } else { + tc_start = content_end + prefix.size() - rstrip_prefix; + result.content = input.substr(0, content_end); + auto tool_calls = json::parse(input.substr(tc_start)); + process_tool_calls(tool_calls); + } + return result; +} + +static void foreach_function(const json & tools, const std::function & fn) { + for (const auto & tool : tools) { + if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) { + LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str()); + continue; + } + fn(tool); + } +} + +static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + + auto tool_call_schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + auto tool_schema = json { + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function["name"]}, + }}, + {"arguments", function["parameters"]}, + }}, + {"required", json::array({"name", "arguments"})}, + }; + if (function.contains("description")) { + tool_schema["description"] = function["description"]; + } + if (inputs.parallel_tool_calls) { + tool_schema["properties"]["id"] = { + {"type", "string"}, + {"minLength", 4}, + }; + tool_schema["required"].push_back("id"); + } + tool_call_schemas.emplace_back(tool_schema); + }); + const auto tool_call = + inputs.parallel_tool_calls + ? json { + {"type", "object"}, + {"properties", { + {"tool_calls", { + {"type", "array"}, + {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { + {"anyOf", tool_call_schemas}, + }}, + {"minItems", 1}, + }}, + }}, + {"required", json::array({"tool_calls"})}, + } + : json { + {"type", "object"}, + {"properties", { + {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { + {"anyOf", tool_call_schemas}, + }}, + }}, + {"required", json::array({"tool_call"})}, + }; + const auto schema = + inputs.tool_choice != "required" + ? json { + {"anyOf", json::array({ + tool_call, + { + {"type", "object"}, + {"properties", { + {"response", inputs.json_schema.is_null() + ? json {{"type", "string"}} + : inputs.json_schema + }, + }}, + {"required", json::array({"response"})}, + }, + })} + } + : tool_call; + + data.grammar_lazy = false; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + builder.add_schema("root", schema); + }, grammar_options); + + auto tweaked_messages = common_chat_template::add_system( + inputs.messages, + "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request"); + + data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_GENERIC; + return data; +} +static common_chat_msg common_chat_parse_generic(const std::string & input) { + json data = json::parse(input); + common_chat_msg result; + result.role = "assistant"; + if (data.contains("tool_calls")) { + for (const auto & tool_call : data["tool_calls"]) { + result.tool_calls.push_back({ + tool_call["name"], + tool_call["arguments"].dump(), + tool_call.contains("id") ? tool_call["id"] : "", + }); + } + } else if (data.contains("tool_call")) { + result.tool_calls.push_back({ + data["tool_call"]["name"], + data["tool_call"]["arguments"].dump(), + /* id= */ "", + }); + } else if (data.contains("response")) { + const auto & response = data["response"]; + result.content = response.is_string() ? response.get() : response.dump(2); + } + return result; +} + +static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + schemas.push_back({ + {"type", "object"}, + {"properties", { + // Important note: the model is probably trained to take a JSON stringified arguments value. + // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object. + {"name", { + {"type", "string"}, + {"const", function["name"]}, + }}, + {"arguments", function["parameters"]}, + {"id", { + {"type", "string"}, + // Nemo's template expects a 9-character alphanumeric ID. + {"pattern", "^[a-zA-Z0-9]{9}$"}, + }}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); + }, grammar_options); + data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true}); + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; + return data; +} +static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) { + return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); +} + +static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { + if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) { + throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); + } + const auto & parameters_properties = parameters.at("properties"); + const auto & parameters_required = parameters.at("required"); + for (const auto & prop : expected_properties) { + if (!parameters_properties.contains(prop)) { + throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); + } + if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) { + throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); + } + } + if (parameters_properties.size() != expected_properties.size()) { + throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", ")); + } +} + +static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) { + auto builtin_tools = json::array(); + common_chat_params data; + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + + auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { + if (name == "wolfram_alpha") { + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py + expect_tool_parameters(name, parameters, {"query"}); + } else if (name == "web_search" || name == "brave_search") { + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py + expect_tool_parameters(name, parameters, {"query"}); + } else if (name == "python" || name == "code_interpreter") { + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py + expect_tool_parameters(name, parameters, {"code"}); + } else { + return false; + } + + std::vector kvs; + for (const auto & [key, value] : parameters.at("properties").items()) { + kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); + } + + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\"")); + builtin_tools.push_back(name); + + return true; + }; + + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + + // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime + if (allow_python_tag_builtin_tools) { + handle_builtin_tool(name, parameters); + } + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) " + "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + + builder.add_schema(name + "-args", parameters) + + " \"}\"")); + data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); + }); + data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true}); + if (!builtin_tools.empty()) { + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + } + builder.add_rule("root", string_join(tool_rules, " | ")); + }, grammar_options); + data.additional_stops.push_back("<|eom_id|>"); + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { + {"tools_in_user_message", false}, + {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, + }); + data.format = allow_python_tag_builtin_tools && !builtin_tools.empty() + ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS + : COMMON_CHAT_FORMAT_LLAMA_3_X; + return data; +} +static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) { + // TODO: tighten & simplify the parser, don't accept leading text context. + static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": "); + static std::regex close_regex("\\}"); + static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)"); + + if (with_builtin_tools) { + std::smatch match; + if (std::regex_match(input, match, builtin_call_regex)) { + auto name = match[1].str(); + auto raw_args = match[2].str(); + + // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing. + auto it_eq = raw_args.find('='); + auto arg_name = raw_args.substr(0, it_eq); + auto arg_value_str = raw_args.substr(it_eq + 1); + auto arg_value = json::parse(arg_value_str); + + return { + /* .role = */ "assistant", + /* .content = */ match.prefix().str(), + /* .tool_calls = */ { + { + /* .name = */ match[1], + /* .arguments = */ (json { + {arg_name, arg_value}, + }).dump(), + /* .id = */ "", + }, + }, + }; + } + } + return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); +} + +static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto args_rule = builder.add_schema(name + "-args", parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); + }); + data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); + builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space"); + }, grammar_options); + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; + return data; +} +static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) { + static std::regex trigger_regex("<|tool▁calls▁begin|>"); + static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); + static std::regex close_regex("```<|tool▁call▁end|>"); + return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex); +} + +static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + fprintf(stderr, "%s\n", __func__); + common_chat_params data; + data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { + {"datetime", "Jan 29 2025 13:00:00 GMT"}, + {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, + }, /* adjust_inputs= */ false); + if (!inputs.tools.is_null() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function["name"]}, + }}, + {"arguments", function["parameters"]}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); + }, grammar_options); + data.grammar_triggers.push_back({" functools[", /* .at_start = */ false}); + data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2; + } else { + data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + } + return data; +} +static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) { + return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); +} + +static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... + // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar + common_chat_params data; + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; + if (!inputs.tools.is_null() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector first_tool_rules; + std::vector subsequent_tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto args_rule = builder.add_schema(name + "-args", parameters); + first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); + subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); + data.grammar_triggers.push_back({name, /* .at_start = */ true}); + data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false}); + }); + auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; + if (inputs.parallel_tool_calls) { + auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; + builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); + } else { + builder.add_rule("root", first_rule); + } + + }, grammar_options); + } + return data; +} + +static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) { + auto expected_it = expected.begin(); + auto tmp_it = it; + while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) { + ++tmp_it; + ++expected_it; + } + if (expected_it == expected.end()) { + it = tmp_it; + return true; + } + return false; +} + +static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) { + static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); + static std::regex close_regex(R"($|(?=>>>))"); + + std::string content; + auto it = input.begin(); + const auto end = input.end(); + + if (consume(it, end, "all\n")) { + std::smatch match; + if (std::regex_search(it, end, match, function_regex)) { + auto fun_it = match.prefix().second; + content = std::string(it, fun_it); + it = fun_it; + } else { + common_chat_msg res; + res.role = "assistant"; + res.content = std::string(it, end); + return res; + } + } + // TODO: tighten & simplify. + auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex); + res.content = content; + return res; +} + +static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt + common_chat_params data; + json tools = inputs.tools.is_null() ? inputs.tools : json::array(); + std::string python_code_argument_name; + auto has_raw_python = false; + + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + const auto & parameters = function["parameters"]; + std::string name = function["name"]; + if (name == "python" || name == "ipython") { + if (!parameters.contains("type")) { + throw std::runtime_error("Missing type in python tool"); + } + has_raw_python = true; + auto type = parameters.at("type"); + if (type == "object") { + auto properties = parameters.at("properties"); + for (auto it = properties.begin(); it != properties.end(); ++it) { + if (it.value().at("type") == "string") { + if (!python_code_argument_name.empty()) { + throw std::runtime_error("Multiple string arguments found in python tool"); + } + python_code_argument_name = it.key(); + } + } + if (python_code_argument_name.empty()) { + throw std::runtime_error("No string argument found in python tool"); + } + } else if (type != "string") { + throw std::runtime_error("Invalid type in python tool: " + type.dump()); + } + } + tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); + }); + if (has_raw_python) { + tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + } + auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; + builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + data.grammar_triggers.push_back({"([\s\S\n]*)$)"); + std::smatch match; + if (std::regex_search(input, match, python_tag_regex)) { + auto code = match[1].str(); + return { + /* .role = */ "assistant", + /* .content = */ match.prefix().str(), + /* .tool_calls = */ { + { + /* .name = */ "python", + /* .arguments = */ (json {{"code", code}}).dump(), + /* .id = */ "", + }, + } + }; + } + static std::regex function_regex(R"()"); + static std::regex close_regex(R"()"); + // TODO: tighten & simplify. + return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); +} + +static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + // (content)?({"name": "foo", "arguments": {"a": 1}})* + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_schema(name + "-call", { + {"type", "object"}, + {"properties", json { + {"name", json {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + })); + }); + auto tool_call = "\"\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"\" space"; + builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + data.grammar_triggers.push_back({"", /* .at_start = */ false}); + // Not really a trigger but need to print this special token to get a successful parse. + data.grammar_triggers.push_back({"", /* .at_start = */ false}); + }, grammar_options); + + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; + return data; +} +static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) { + try { + std::regex start_pattern(R"([\n\s]*)"); + std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); + std::regex end_pattern(R"([\n\s]*[\n\s]*$)"); + + auto end = input.end(); + std::sregex_iterator rend; + std::sregex_iterator rit(input.begin(), end, start_pattern); + if (rit == rend) { + return { + /* .role = */ "assistant", + /* .content = */ input, + /* .tool_calls = */ {}, + }; + } + + common_chat_msg result; + result.role = "assistant"; + result.content = rit->prefix(); + + auto it = rit->suffix().first; + while (it != end) { + json call; + if (!parse_json(it, end, call)) { + throw std::runtime_error("Failed to parse json tool call"); + } + const auto & arguments = call["arguments"]; + result.tool_calls.push_back({ + call["name"], + arguments.dump(), + // arguments.is_string() ? arguments.get() : arguments.dump(), + /* id= */ "", + }); + rit = {it, end, middle_pattern}; + if (rit != rend) { + it = rit->suffix().first; + } else { + rit = {it, end, end_pattern}; + if (rit == rend) { + throw std::runtime_error("Malformed input, missing "); + } + break; + } + } + return result; + } catch (const std::exception & e) { + return { + /* .role = */ "assistant", + /* .content = */ input, + /* .tool_calls = */ {}, + }; + } +} + +static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + data.grammar_lazy = false; + if (!inputs.json_schema.is_null()) { + if (!inputs.grammar.empty()) { + throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); + } + data.grammar = json_schema_to_grammar(inputs.json_schema); + } else { + data.grammar = inputs.grammar.empty(); + } + return data; +} + +common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none"; + LOG_DBG("[%s] has_tools=%s\n", __func__, has_tools ? "true" : "false"); + + if (has_tools && !inputs.grammar.empty()) { + throw std::runtime_error("Cannot specify grammar with tools"); + } + + const auto & src = tmpl.source(); + if (src.find(">>>all") != std::string::npos) { + // Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when + return common_chat_params_init_functionary_v3_2(tmpl, inputs); + } + if (src.find(" functools[") != std::string::npos) { + // Firefunction v2 requires datetime and functions in the context, even w/o tools. + return common_chat_params_init_firefunction_v2(tmpl, inputs); + } + + if (!has_tools) { + return common_chat_params_init_without_tools(tmpl, inputs); + } + + if (src.find("") != std::string::npos) { + return common_chat_params_init_hermes_2_pro(tmpl, inputs); + } + if (src.find("<|start_header_id|>") != std::string::npos + && src.find("ipython<|end_header_id|>") != std::string::npos) { + auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; + return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools); + } + if (src.find("<|tool▁calls▁begin|>") != std::string::npos) { + return common_chat_params_init_deepseek_r1(tmpl, inputs); + } + if (src.find("[TOOL_CALLS]") != std::string::npos) { + return common_chat_params_init_mistral_nemo(tmpl, inputs); + } + return common_chat_params_init_generic(tmpl, inputs); +} + +static common_chat_msg common_chat_parse_content_only(const std::string & input) { + return { + /* .role = */ "assistant", + /* .content = */ input, + /* .tool_calls = */ {}, + }; +} + +common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) { + switch (format) { + case COMMON_CHAT_FORMAT_CONTENT_ONLY: + return common_chat_parse_content_only(input); + case COMMON_CHAT_FORMAT_GENERIC: + return common_chat_parse_generic(input); + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: + return common_chat_parse_mistral_nemo(input); + case COMMON_CHAT_FORMAT_LLAMA_3_X: + return common_chat_parse_llama_3_1(input); + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: + return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true); + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: + return common_chat_parse_deepseek_r1(input); + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: + return common_chat_parse_functionary_v3_2(input); + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: + return common_chat_parse_functionary_v3_1_llama_3_1(input); + case COMMON_CHAT_FORMAT_HERMES_2_PRO: + return common_chat_parse_hermes_2_pro(input); + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: + return common_chat_parse_firefunction_v2(input); + default: + throw std::runtime_error("Unsupported format: " + common_chat_format_name(format)); + } +} diff --git a/common/chat.hpp b/common/chat.hpp new file mode 100644 index 000000000..ca165aa13 --- /dev/null +++ b/common/chat.hpp @@ -0,0 +1,50 @@ +// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers. + +#pragma once + +#include "common.h" +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +struct common_chat_inputs { + json messages; + json tools; + json tool_choice; + json json_schema; + bool parallel_tool_calls; + bool stream; + std::string grammar; + bool add_generation_prompt = true; +}; + +enum common_chat_format { + COMMON_CHAT_FORMAT_CONTENT_ONLY, + COMMON_CHAT_FORMAT_GENERIC, + COMMON_CHAT_FORMAT_MISTRAL_NEMO, + COMMON_CHAT_FORMAT_LLAMA_3_X, + COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, + COMMON_CHAT_FORMAT_DEEPSEEK_R1, + COMMON_CHAT_FORMAT_FIREFUNCTION_V2, + COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, + COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, + COMMON_CHAT_FORMAT_HERMES_2_PRO, + + COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats +}; + +struct common_chat_params { + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + json prompt; + std::string grammar; + bool grammar_lazy = false; + std::vector grammar_triggers; + std::vector additional_stops; +}; + +struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params); +std::string common_chat_format_name(common_chat_format format); +common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); diff --git a/common/common.cpp b/common/common.cpp index 6dea8e3d2..6c81d18f9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -12,6 +12,7 @@ #include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" +#include "chat.hpp" #include "chat-template.hpp" #include @@ -1774,11 +1775,13 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { - auto chat_template = minja::chat_template(tmpl, "", ""); - chat_template.apply({{ + auto chat_template = common_chat_template(tmpl, "", ""); + common_chat_inputs inputs; + inputs.messages = json::array({{ {"role", "user"}, {"content", "test"}, - }}, json(), true); + }}); + common_chat_params_init(chat_template, inputs); return true; } catch (const std::exception & e) { LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); @@ -1800,7 +1803,10 @@ std::string common_chat_apply_template( for (const auto & msg : msgs) { messages.push_back({{"role", msg.role}, {"content", msg.content}}); } - return tmpl.apply(messages, /* tools= */ json(), add_ass); + common_chat_inputs inputs; + inputs.messages = messages; + inputs.add_generation_prompt = add_ass; + return common_chat_params_init(tmpl, inputs).prompt; } int alloc_size = 0; @@ -1855,10 +1861,10 @@ std::string common_chat_format_single( std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) { std::vector msgs = { - {"system", "You are a helpful assistant"}, - {"user", "Hello"}, - {"assistant", "Hi there"}, - {"user", "How are you?"}, + {"system", "You are a helpful assistant", {}}, + {"user", "Hello", {}}, + {"assistant", "Hi there", {}}, + {"user", "How are you?", {}}, }; return common_chat_apply_template(tmpl, msgs, true, use_jinja); } diff --git a/common/common.h b/common/common.h index 571260372..6c1809277 100644 --- a/common/common.h +++ b/common/common.h @@ -109,6 +109,11 @@ enum common_conversation_mode { COMMON_CONVERSATION_MODE_AUTO = 2, }; +struct common_grammar_trigger { + std::string word; + bool at_start; +}; + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -154,7 +159,10 @@ struct common_params_sampling { COMMON_SAMPLER_TYPE_TEMPERATURE, }; - std::string grammar; // optional BNF-like grammar to constrain sampling + std::string grammar; // optional BNF-like grammar to constrain sampling + bool grammar_lazy = false; + std::vector grammar_trigger_words; // optional trigger words to trigger lazy grammar + std::vector grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens. std::vector logit_bias; // logit biases to apply @@ -602,10 +610,17 @@ std::string common_detokenize( // Chat template utils // +struct common_tool_call { + std::string name; + std::string arguments; + std::string id; +}; + // same with llama_chat_message, but uses std::string struct common_chat_msg { std::string role; std::string content; + std::vector tool_calls; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 4d426b6bd..1f47e313e 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -343,7 +343,7 @@ static std::string format_literal(const std::string & literal) { class SchemaConverter { private: - friend std::string build_grammar(const std::function & cb); + friend std::string build_grammar(const std::function & cb, const common_grammar_options & options); std::function _fetch_json; bool _dotall; std::map _rules; @@ -764,10 +764,11 @@ private: public: SchemaConverter( const std::function & fetch_json, - bool dotall) + bool dotall, + bool compact_spaces) : _fetch_json(fetch_json), _dotall(dotall) { - _rules["space"] = SPACE_RULE; + _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE; } void resolve_refs(json & schema, const std::string & url) { @@ -991,16 +992,16 @@ public: }; std::string json_schema_to_grammar(const json & schema) { - return build_grammar([&](const llama_grammar_builder & callbacks) { + return build_grammar([&](const common_grammar_builder & callbacks) { auto copy = schema; callbacks.resolve_refs(copy); callbacks.add_schema("", copy); }); } -std::string build_grammar(const std::function & cb) { - SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false); - llama_grammar_builder builder { +std::string build_grammar(const std::function & cb, const common_grammar_options & options) { + SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces); + common_grammar_builder builder { /* .add_rule = */ [&](const std::string & name, const std::string & rule) { return converter._add_rule(name, rule); }, diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 4f43ab3a5..ba4112cb9 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -7,10 +7,15 @@ std::string json_schema_to_grammar(const nlohmann::ordered_json & schema); -struct llama_grammar_builder { +struct common_grammar_builder { std::function add_rule; std::function add_schema; std::function resolve_refs; }; -std::string build_grammar(const std::function & cb); +struct common_grammar_options { + bool dotall = false; + bool compact_spaces = false; +}; + +std::string build_grammar(const std::function & cb, const common_grammar_options & options = {}); diff --git a/common/sampling.cpp b/common/sampling.cpp index 7241ac321..bc7e49fdb 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -151,9 +151,18 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co lparams.no_perf = params.no_perf; + std::vector trigger_words; + trigger_words.reserve(params.grammar_trigger_words.size()); + for (const auto & str : params.grammar_trigger_words) { + trigger_words.push_back(str.word.c_str()); + } auto * result = new common_sampler { /* .params = */ params, - /* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"), + /* .grmr = */ params.grammar_lazy + ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", + trigger_words.data(), trigger_words.size(), + params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) + : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"), /* .chain = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 17a0e27c4..a610e6a0b 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -76,7 +76,7 @@ int main(int argc, char** argv) { grammar_str = buffer.str(); } - llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); + llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0); if (grammar == nullptr) { fprintf(stdout, "Failed to initialize llama_grammar\n"); return 1; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index da2a03ab9..e654d3542 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -254,7 +254,7 @@ int main(int argc, char ** argv) { } } - const bool add_bos = llama_vocab_get_add_bos(vocab); + const bool add_bos = llama_vocab_get_add_bos(vocab) && !params.use_jinja; if (!llama_model_has_encoder(model)) { GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); } @@ -264,9 +264,9 @@ int main(int argc, char ** argv) { std::vector embd_inp; auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { - common_chat_msg new_msg{role, content}; + common_chat_msg new_msg{role, content, {}}; auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); - chat_msgs.push_back({role, content}); + chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); return formatted; }; @@ -503,12 +503,14 @@ int main(int argc, char ** argv) { std::vector embd; - // tokenized antiprompts - std::vector> antiprompt_ids; + // single-token antiprompts + std::vector antiprompt_token; - antiprompt_ids.reserve(params.antiprompt.size()); for (const std::string & antiprompt : params.antiprompt) { - antiprompt_ids.emplace_back(::common_tokenize(ctx, antiprompt, false, true)); + auto ids = ::common_tokenize(ctx, antiprompt, false, true); + if (ids.size() == 1) { + antiprompt_token.push_back(ids[0]); + } } if (llama_model_has_encoder(model)) { @@ -753,14 +755,11 @@ int main(int argc, char ** argv) { // check for reverse prompt using special tokens llama_token last_token = common_sampler_last(smpl); - for (std::vector ids : antiprompt_ids) { - if (ids.size() == 1 && last_token == ids[0]) { - if (params.interactive) { - is_interacting = true; - } - is_antiprompt = true; - break; + if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) { + if (params.interactive) { + is_interacting = true; } + is_antiprompt = true; } if (is_antiprompt) { diff --git a/examples/server/README.md b/examples/server/README.md index 44da503df..ce1ae8858 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -1117,6 +1117,82 @@ curl http://localhost:8080/v1/chat/completions \ }' ``` +... and even tool usage (needs `--jinja` flag): + + ```shell + llama-server --jinja -hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf -fa + + # https://huggingface.co/meetkai/functionary-medium-v3.2 + llama-server --jinja -hfr bartowski/functionary-medium-v3.2-GGUF -hff functionary-medium-v3.2-IQ4_XS.gguf -fa + + # https://huggingface.co/meetkai/functionary-medium-v3.1 + llama-server --jinja -hfr meetkai/functionary-medium-v3.1-GGUF -hff functionary-medium-llama-3.1.Q4_0.gguf -fa + + curl http://localhost:8080/v1/chat/completions -d '{ + "model": "gpt-3.5-turbo", + "tools": [ + { + "type":"function", + "function":{ + "name":"get_current_weather", + "description":"Get the current weather in a given location", + "parameters":{ + "type":"object", + "properties":{ + "location":{ + "type":"string", + "description":"The city and state, e.g. San Francisco, CA" + } + }, + "required":["location"] + } + } + } + ], + "messages": [ + { + "role": "user", + "content": "What is the weather like in Istanbul?." + } + ] + }' + ``` + +
+ Show output + + ```json + { + "choices": [ + { + "finish_reason": "tool", + "index": 0, + "message": { + "content": null, + "tool_calls": [ + { + "name": "python", + "arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}" + } + ], + "role": "assistant" + } + } + ], + "created": 1727287211, + "model": "gpt-3.5-turbo", + "object": "chat.completion", + "usage": { + "completion_tokens": 16, + "prompt_tokens": 44, + "total_tokens": 60 + }, + "id": "chatcmpl-Htbgh9feMmGM0LEH2hmQvwsCxq3c6Ni8" + } + ``` + +
+ ### POST `/v1/embeddings`: OpenAI-compatible embeddings API This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b9aa5c81c..d1ea343dd 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -113,10 +113,11 @@ struct slot_params { struct common_params_speculative speculative; // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; json to_json() const { std::vector samplers; @@ -164,6 +165,8 @@ struct slot_params { {"n_probs", sampling.n_probs}, {"min_keep", sampling.min_keep}, {"grammar", sampling.grammar}, + // {"grammar_trigger_words", sampling.grammar_trigger_words}, + {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, @@ -325,12 +328,50 @@ struct server_task { if (data.contains("json_schema") && !data.contains("grammar")) { try { auto schema = json_value(data, "json_schema", json::object()); - params.sampling.grammar = json_schema_to_grammar(schema); + LOG_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + params.sampling.grammar = json_schema_to_grammar(schema); + LOG_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); } catch (const std::exception & e) { throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); } } else { - params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + LOG_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); + LOG_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); + } + + { + auto it = data.find("chat_format"); + if (it != data.end()) { + params.oaicompat_chat_format = static_cast(it->get()); + LOG_DBG("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str()); + } else { + params.oaicompat_chat_format = defaults.oaicompat_chat_format; + } + } + + { + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + common_grammar_trigger trigger; + trigger.word = t.at("word"); + trigger.at_start = t.at("at_start"); + + auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str()); + params.sampling.grammar_trigger_tokens.push_back(ids[0]); + continue; + } + LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str()); + params.sampling.grammar_trigger_words.push_back(trigger); + } + } + if (params.sampling.grammar_lazy) { + GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0); + } } { @@ -382,22 +423,12 @@ struct server_task { } { - const auto & samplers = data.find("samplers"); + const auto samplers = data.find("samplers"); if (samplers != data.end()) { if (samplers->is_array()) { - std::vector sampler_names; - for (const auto & name : *samplers) { - if (name.is_string()) { - sampler_names.emplace_back(name); - } - } - params.sampling.samplers = common_sampler_types_from_names(sampler_names, false); + params.sampling.samplers = common_sampler_types_from_names(*samplers, false); } else if (samplers->is_string()){ - std::string sampler_string; - for (const auto & name : *samplers) { - sampler_string += name; - } - params.sampling.samplers = common_sampler_types_from_chars(sampler_string); + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); } } else { params.sampling.samplers = defaults.sampling.samplers; @@ -544,7 +575,7 @@ struct completion_token_output { struct server_task_result_cmpl_final : server_task_result { int index = 0; - std::string content; + std::string content; llama_tokens tokens; bool stream; @@ -566,10 +597,11 @@ struct server_task_result_cmpl_final : server_task_result { slot_params generation_params; // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; virtual int get_index() override { return index; @@ -663,18 +695,38 @@ struct server_task_result_cmpl_final : server_task_result { json to_json_oaicompat_chat() { std::string finish_reason = "length"; + common_chat_msg message; if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = "stop"; + message = common_chat_parse(content, oaicompat_chat_format); + finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls"; + } else { + message.content = content; } - json choice = json{ + json tool_calls; + if (!message.tool_calls.empty()) { + tool_calls = json::array(); + for (const auto & tc : message.tool_calls) { + tool_calls.push_back({ + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments}, + }}, + {"id", tc.id.empty() ? json() : json(tc.id)}, + }); + } + } + + json choice { {"finish_reason", finish_reason}, {"index", 0}, {"message", json { - {"content", content}, - {"role", "assistant"} - } - }}; + {"content", message.content}, + {"tool_calls", tool_calls}, + {"role", "assistant"}, + }}, + }; if (!stream && probs_output.size() > 0) { choice["logprobs"] = json{ @@ -716,7 +768,7 @@ struct server_task_result_cmpl_final : server_task_result { finish_reason = "stop"; } - json choice = json{ + json choice = json { {"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()} @@ -1191,6 +1243,8 @@ struct server_slot { llama_token sampled; + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + // stats size_t n_sent_text = 0; // number of sent text character @@ -1815,17 +1869,16 @@ struct server_context { if (use_jinja) { auto templates = common_chat_templates_from_model(model, ""); + common_chat_inputs inputs; + inputs.messages = json::array({{ + {"role", "user"}, + {"content", "test"}, + }}); GGML_ASSERT(templates.template_default); try { - templates.template_default->apply({{ - {"role", "user"}, - {"content", "test"}, - }}, json(), true); + common_chat_params_init(*templates.template_default, inputs); if (templates.template_tool_use) { - templates.template_tool_use->apply({{ - {"role", "user"}, - {"content", "test"}, - }}, json(), true); + common_chat_params_init(*templates.template_tool_use, inputs); } return true; } catch (const std::exception & e) { @@ -2275,11 +2328,11 @@ struct server_context { res->id_slot = slot.id; res->index = slot.index; - res->content = slot.generated_text; - res->tokens = slot.generated_tokens; + res->content = std::move(slot.generated_text); + res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); - res->response_fields = slot.params.response_fields; + res->response_fields = std::move(slot.params.response_fields); res->truncated = slot.truncated; res->n_decoded = slot.n_decoded; @@ -2290,12 +2343,12 @@ struct server_context { res->stop = slot.stop; res->post_sampling_probs = slot.params.post_sampling_probs; - res->verbose = slot.params.verbose; - res->stream = slot.params.stream; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - + res->verbose = slot.params.verbose; + res->stream = slot.params.stream; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->oaicompat_chat_format = slot.params.oaicompat_chat_format; // populate res.probs_output if (slot.params.sampling.n_probs > 0) { if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { @@ -2773,6 +2826,11 @@ struct server_context { // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; + auto accept_special_token = [&](server_slot & slot, llama_token token) { + const auto & trigger_tokens = slot.params.sampling.grammar_trigger_tokens; + return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end(); + }; + // frist, add sampled tokens from any ongoing sequences for (auto & slot : slots) { if (slot.state != SLOT_STATE_GENERATING) { @@ -3136,7 +3194,7 @@ struct server_context { completion_token_output result; result.tok = id; - result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special); + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs if (slot.params.sampling.n_probs > 0) { @@ -3225,7 +3283,7 @@ struct server_context { completion_token_output result; result.tok = ids[i]; - result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special); + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // set later // TODO: set result.probs @@ -3722,6 +3780,8 @@ int main(int argc, char ** argv) { { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, { "chat_template", ctx_server.chat_templates.template_default->source() }, + { "bos_token", ctx_server.chat_templates.template_default->bos_token() }, + { "eos_token", ctx_server.chat_templates.template_default->eos_token() }, { "build_info", build_info }, }; if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { @@ -3763,7 +3823,9 @@ int main(int argc, char ** argv) { std::vector tasks; try { - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true); + const auto & prompt = data.at("prompt"); + LOG_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); tasks.reserve(tokenized_prompts.size()); for (size_t i = 0; i < tokenized_prompts.size(); i++) { server_task task = server_task(type); @@ -3779,8 +3841,8 @@ int main(int argc, char ** argv) { task.id_selected_slot = json_value(data, "id_slot", -1); // OAI-compat - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; // oaicompat_model is already populated by params_from_json_cmpl tasks.push_back(task); @@ -3949,14 +4011,14 @@ int main(int argc, char ** argv) { }; const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + LOG_DBG("request: %s\n", req.body.c_str()); if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } auto body = json::parse(req.body); - const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default; - json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja); + json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates); return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, @@ -3966,6 +4028,13 @@ int main(int argc, char ** argv) { OAICOMPAT_TYPE_CHAT); }; + // same with handle_chat_completions, but without inference part + const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { + auto body = json::parse(req.body); + json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates); + res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); + }; + const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { json models = { {"object", "list"}, @@ -4124,14 +4193,6 @@ int main(int argc, char ** argv) { res_ok(res, root); }; - const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { - auto body = json::parse(req.body); - const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default; - json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja); - - res_ok(res, {{ "prompt", data.at("prompt") }}); - }; - const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); }; diff --git a/examples/server/tests/README.md b/examples/server/tests/README.md index 5787276ab..1de0eb30e 100644 --- a/examples/server/tests/README.md +++ b/examples/server/tests/README.md @@ -31,8 +31,9 @@ It's possible to override some scenario steps values with environment variables: | `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/llama-server` | | `DEBUG` | to enable steps and server verbose mode `--verbose` | | `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` | +| `LLAMA_CACHE` | by default server tests re-download models to the `tmp` subfolder. Set this to your cache (e.g. `$HOME/Library/Caches/llama.cpp` on Mac or `$HOME/.cache/llama.cpp` on Unix) to avoid this | -To run slow tests: +To run slow tests (will download many models, make sure to set `LLAMA_CACHE` if needed): ```shell SLOW_TESTS=1 ./tests.sh @@ -44,10 +45,16 @@ To run with stdout/stderr display in real time (verbose output, but useful for d DEBUG=1 ./tests.sh -s -v -x ``` -To run single test unit: +To run all the tests in a file: ```shell -./tests.sh unit/test_{name of test case here}.py -v -x +./tests.sh unit/test_chat_completion.py.py -v -x +``` + +To run a single test: + +```shell +./tests.sh unit/test_chat_completion.py::test_invalid_chat_completion_req ``` Hint: You can compile and run test in single command, useful for local developement: diff --git a/examples/server/tests/pytest.ini b/examples/server/tests/pytest.ini new file mode 100644 index 000000000..6df308df7 --- /dev/null +++ b/examples/server/tests/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + serial diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh index 1e0777de3..33fa8cc64 100755 --- a/examples/server/tests/tests.sh +++ b/examples/server/tests/tests.sh @@ -6,9 +6,18 @@ cd $SCRIPT_DIR set -eu +if [[ "${SLOW_TESTS:-0}" == 1 ]]; then + # Slow tests for tool calls need quite a few models ahead of time to avoid timing out. + python $SCRIPT_DIR/../../../scripts/fetch_server_test_models.py +fi + if [ $# -lt 1 ] then - pytest -v -x + if [[ "${SLOW_TESTS:-0}" == 1 ]]; then + pytest -v -x + else + pytest -v -x -m "not slow" + fi else pytest "$@" fi diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index add3f810f..0be04bab5 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -2,7 +2,7 @@ import pytest from openai import OpenAI from utils import * -server = ServerPreset.tinyllama2() +server: ServerProcess @pytest.fixture(autouse=True) def create_server(): @@ -13,11 +13,12 @@ def create_server(): @pytest.mark.parametrize( "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template", [ - (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None), - (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None), - (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), + (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), + # TODO: fix testing of non-tool jinja mode + # (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), + # (None, "Book", "What is the best book", 8, "I want to play with", 23, 8, "length", True, "This is not a chat template, it is"), + # ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), ] ) def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py new file mode 100644 index 000000000..e6ed9c9be --- /dev/null +++ b/examples/server/tests/unit/test_tool_call.py @@ -0,0 +1,352 @@ +import pytest +from utils import * + +server: ServerProcess + +TIMEOUT_SERVER_START = 15*60 +TIMEOUT_HTTP_REQUEST = 60 + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.model_alias = "tinyllama-2-tool-call" + server.server_port = 8081 + + +TEST_TOOL = { + "type":"function", + "function": { + "name": "test", + "description": "", + "parameters": { + "type": "object", + "properties": { + "success": {"type": "boolean", "const": True}, + }, + "required": ["success"] + } + } +} + +PYTHON_TOOL = { + "type": "function", + "function": { + "name": "python", + "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The code to run in the ipython interpreter." + } + }, + "required": ["code"] + } + } +} + +WEATHER_TOOL = { + "type":"function", + "function":{ + "name":"get_current_weather", + "description":"Get the current weather in a given location", + "parameters":{ + "type":"object", + "properties":{ + "location":{ + "type":"string", + "description":"The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'" + } + }, + "required":["location"] + } + } +} + + +def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None): + n_predict = 512 + global server + # server = ServerPreset.stories15m_moe() + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "tool_choice": "required", + "tools": [tool], + "parallel_tool_calls": False, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + assert expected_function_name == tool_call["function"]["name"] + actual_arguments = tool_call["function"]["arguments"] + assert isinstance(actual_arguments, str) + if argument_key is not None: + actual_arguments = json.loads(actual_arguments) + assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + + +@pytest.mark.parametrize("template_name,tool,argument_key", [ + ("google-gemma-2-2b-it", TEST_TOOL, "success"), + ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), +]) +def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None): + do_test_completion_with_required_tool_tiny(template_name, tool, argument_key) + + +@pytest.mark.slow +@pytest.mark.parametrize("template_name,tool,argument_key", [ + ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"), + ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), + ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"), + ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), + ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"), + ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"), + ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"), + ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"), + ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"), + ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"), + ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"), + ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"), +]) +def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None): + do_test_completion_with_required_tool_tiny(template_name, tool, argument_key) + + +@pytest.mark.slow +@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [ + (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + (TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + # TODO: fix these + # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), +]) +def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None): + n_predict = 512 + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if template_override: + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "tool_choice": "required", + "tools": [tool], + "parallel_tool_calls": False, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + assert expected_function_name == tool_call["function"]["name"] + actual_arguments = tool_call["function"]["arguments"] + assert isinstance(actual_arguments, str) + if argument_key is not None: + actual_arguments = json.loads(actual_arguments) + assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + + +def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): + global server + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "say hello world with python"}, + ], + "tools": tools if tools else None, + "tool_choice": tool_choice, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' + + +@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ + ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None), + ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None), + ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'), +]) +def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): + do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice) + + +@pytest.mark.slow +@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ + ("meetkai-functionary-medium-v3.2", 256, [], None), + ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None), + ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'), + ("meetkai-functionary-medium-v3.1", 256, [], None), + ("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None), + ("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [], None), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'), +]) +def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): + do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice) + + +@pytest.mark.slow +@pytest.mark.parametrize("hf_repo,template_override", [ + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), +]) +def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | None] | None): + global server + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = 512 + server.model_hf_repo = hf_repo + server.model_hf_file = None + if template_override: + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 256, + "messages": [ + {"role": "user", "content": "What is the weather in Istanbul?"}, + ], + "tools": [WEATHER_TOOL], + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"] + actual_arguments = json.loads(tool_call["function"]["arguments"]) + assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" + location = actual_arguments["location"] + assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}" + assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' + + +@pytest.mark.slow +@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ + (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), + ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), +]) +def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None): + global server + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = 128 + server.model_hf_repo = hf_repo + server.model_hf_file = None + if template_override: + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 256, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "say hello world with python"}, + ], + "tools": [PYTHON_TOOL], + # Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n print("Hello, World!")\nhello_world()` which is correct but a pain to test. + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] + actual_arguments = tool_call["function"]["arguments"] + if expected_arguments_override is not None: + assert actual_arguments == expected_arguments_override + else: + actual_arguments = json.loads(actual_arguments) + assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" + code = actual_arguments["code"] + assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" + assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}' diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 9964db2f9..ce0680662 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -26,7 +26,7 @@ from re import RegexFlag import wget -DEFAULT_HTTP_TIMEOUT = 10 if "LLAMA_SANITIZE" not in os.environ else 30 +DEFAULT_HTTP_TIMEOUT = 12 if "LLAMA_SANITIZE" not in os.environ else 30 class ServerResponse: @@ -41,7 +41,7 @@ class ServerProcess: server_port: int = 8080 server_host: str = "127.0.0.1" model_hf_repo: str = "ggml-org/models" - model_hf_file: str = "tinyllamas/stories260K.gguf" + model_hf_file: str | None = "tinyllamas/stories260K.gguf" model_alias: str = "tinyllama-2" temperature: float = 0.8 seed: int = 42 @@ -191,7 +191,7 @@ class ServerProcess: creationflags=flags, stdout=sys.stdout, stderr=sys.stdout, - env={**os.environ, "LLAMA_CACHE": "tmp"}, + env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None, ) server_instances.add(self) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index c5987250c..3d2c04666 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -17,6 +17,7 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" #include "minja.hpp" +#include "chat.hpp" #include "chat-template.hpp" #include @@ -376,7 +377,7 @@ inline std::string format_chat(const common_chat_template & tmpl, const std::vec throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } - chat.push_back({role, content}); + chat.push_back({role, content, /* tool_calls= */ {}}); } const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); @@ -483,14 +484,13 @@ static bool ends_with(const std::string & str, const std::string & suffix) { static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { if (!text.empty() && !stop.empty()) { - const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { - if (stop[char_index] == text_last_char) { - const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) { - return text.size() - char_index - 1; - } + auto it = std::find(stop.rbegin(), stop.rend(), text.back()); + while (it != stop.rend()) { + size_t length = std::distance(it, stop.rend()); + if (text.length() >= length && 0 == text.compare(text.length() - length, length, stop)) { + return text.length() - length; } + it = std::find(std::next(it), stop.rend(), text.back()); } } @@ -580,21 +580,30 @@ static json oaicompat_completion_params_parse(const json & body) { static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ - const common_chat_template & tmpl, - bool use_jinja) + bool use_jinja, + const common_chat_templates & chat_templates) { json llama_params; + const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use + ? *chat_templates.template_tool_use + : *chat_templates.template_default; auto tools = json_value(body, "tools", json()); - auto has_tools = tools.is_array() && !tools.empty(); + auto stream = json_value(body, "stream", false); - if (has_tools) { - if (use_jinja) { - LOG_WRN("tools param is not fully supported yet\n"); - } else { + if (tools.is_array() && !tools.empty()) { + if (stream) { + throw std::runtime_error("Cannot use tools with stream"); + } + if (!use_jinja) { throw std::runtime_error("tools param requires --jinja flag"); } } + if (!use_jinja) { + if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) { + throw std::runtime_error("Unsupported param: tool_choice"); + } + } // Handle "stop" field if (body.contains("stop") && body.at("stop").is_string()) { @@ -619,7 +628,38 @@ static json oaicompat_completion_params_parse( // Apply chat template to the list of messages if (use_jinja) { - llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); + auto tool_choice = json_value(body, "tool_choice", std::string("auto")); + if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") { + throw std::runtime_error("Invalid tool_choice: " + tool_choice); + } + if (tool_choice != "none" && llama_params.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + common_chat_inputs inputs; + inputs.messages = body.at("messages"); + inputs.tools = tools; + inputs.tool_choice = tool_choice; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.stream = stream; + // TODO: support mixing schema w/ tools beyond generic format. + inputs.json_schema = json_value(llama_params, "json_schema", json::object()); + auto chat_params = common_chat_params_init(tmpl, inputs); + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + grammar_triggers.push_back({ + {"word", trigger.word}, + {"at_start", trigger.at_start}, + }); + } + llama_params["grammar_triggers"] = grammar_triggers; + for (const auto & stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); + } } else { llama_params["prompt"] = format_chat(tmpl, body.at("messages")); } @@ -638,14 +678,6 @@ static json oaicompat_completion_params_parse( throw std::runtime_error("top_logprobs requires logprobs to be set to true"); } - // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params { "tool_choice" }; - for (const auto & param : unsupported_params) { - if (body.contains(param)) { - throw std::runtime_error("Unsupported param: " + param); - } - } - // Copy remaining properties to llama_params // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp diff --git a/include/llama.h b/include/llama.h index 3b75e7607..61907ed40 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1199,6 +1199,18 @@ extern "C" { const char * grammar_str, const char * grammar_root); + /// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639 + /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. LLAMA_API struct llama_sampler * llama_sampler_init_penalties( int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) diff --git a/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja b/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja new file mode 100644 index 000000000..f5baef30b --- /dev/null +++ b/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja @@ -0,0 +1,202 @@ + +{%- macro json_to_python_type(json_spec) %} +{%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + +{%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == "array" %} + {{- "List[" + json_to_python_type(json_spec.items) + "]"}} +{%- elif json_spec.type == "object" %} + {{- "Dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} +{%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} +{%- else %} + {{- "Any" }} +{%- endif %} +{%- endmacro %} + +{%- macro old_tool_parser(tools) %} +{%- for tool in tools %} + {%- if loop.index0 != 0 %} + {{- '\n\n' }} + {%- endif %} + {{- '```python\ndef ' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameter_definitions|items %} + {%- if loop.index0 != 0 %} + {{- ', '}} + {%- endif %} + {{- param_name + ': ' }} + {%- if not param_fields.required %} + {{- 'Optional[' + param_fields.type + '] = None'}} + {%- else %} + {{- param_fields.type }} + {%- endif %} + {%- endfor %} + {{- ') -> List[Dict]:\n """'}} + {{- tool.description }} + {%- if tool.parameter_definitions|length != 0 %} + {{- '\n\n Args:\n '}} + {%- for param_name, param_fields in tool.parameter_definitions|items %} + {%- if loop.index0 != 0 %} + {{- '\n ' }} + {%- endif %} + {{- param_name + ' ('}} + {%- if not param_fields.required %} + {{- 'Optional[' + param_fields.type + ']'}} + {%- else %} + {{- param_fields.type }} + {%- endif %} + {{- '): ' + param_fields.description }} + {%- endfor %} + {%- endif %} + {{- '\n """\n pass\n```' }} +{%- endfor %} +{%- endmacro %} + +{%- macro new_tool_parser(tools) %} +{%- for tool in tools %} + {%- if loop.index0 != 0 %} + {{- '\n\n'}} + {%- endif %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{-'```python +def ' + tool.name + '('}} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.index0 != 0 %} + {{- ', '}} + {%- endif %} + {{-param_name + ": "}} + {%- if not param_name in tool.parameters.required %} + {{-'Optional[' + json_to_python_type(param_fields) + '] = None'}} + {%- else %} + {{- json_to_python_type(param_fields) }} + {%- endif %} + {%- endfor %} + {{- ') -> List[Dict]: + """'}} + {{- tool.description }} + {%- if tool.parameters.properties|length != 0 %} + {{- '\n\n Args:\n '}} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.index0 != 0 %} + {{- '\n ' }} + {%- endif %} + {{- param_name + ' ('}} + {%- if not param_name in tool.parameters.required %} + {{-'Optional[' + json_to_python_type(param_fields) + ']'}} + {%- else %} + {{- json_to_python_type(param_fields) }} + {%- endif %} + {{- '): ' + param_fields.description }} + {%- endfor %} + {%- endif %} + {{- '\n """\n pass\n```' }} +{%- endfor %} +{%- endmacro %} + +{{- bos_token }} +{%- if messages[0]['role'] == 'system' %} + {%- set loop_messages = messages[1:] %} + {%- set system_message = messages[0]['content'] %} +{%- else %} + {%- set loop_messages = messages %} + {%- set system_message = '## Task and Context\nYou help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user\'s needs as best you can, which will be wide-ranging.\n\n## Style Guide\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.' %} +{%- endif %} +{{- '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' }} +{{- '# Safety Preamble' }} +{{- ' +The instructions in this section override those in the task description and style guide sections. Don\'t answer questions that are harmful or immoral.' }} +{{- ' + +# System Preamble' }} +{{- ' +## Basic Rules' }} +{{- ' +You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user\'s requests, you cite your sources in your answers, according to those instructions.' }} +{{- ' + +# User Preamble' }} +{{- ' +' + system_message }} +{{-' + +## Available Tools +Here is a list of tools that you have available to you: + +'}} +{%- set ns = namespace(new_tools=true) %} +{%- for tool in tools %} + {%- if tool.parameter_definitions is defined %} + {%- set ns.new_tools = false %} + {%- endif %} +{%- endfor %} +{%- if ns.new_tools %} + {{- new_tool_parser(tools) }} +{%- else %} + {{- old_tool_parser(tools) }} +{%- endif %} +{{- '<|END_OF_TURN_TOKEN|>'}} +{%- for message in loop_messages %} + {%- set content = message['content'] %} + {%- if message.role == 'user' %} + {{- '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content|trim + '<|END_OF_TURN_TOKEN|>' }} + {%- elif message.role == 'system' %} + {{- '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + content|trim + '<|END_OF_TURN_TOKEN|>' }} + {%- elif message.role == 'assistant' and message.tool_calls is defined %} + {{- '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }} + {%- if message.content is defined %} + {{- message.content|trim }} + {%- endif %} + {{- '\nAction:\n```json\n[\n' }} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{\n'|indent(4, first=true) }} + {{- '"tool_name": "'|indent(8, first=true) + tool_call.name + '",\n' }} + {{- '"parameters": '|indent(8, first=true) }} + {%- if tool_call.arguments is defined and tool_call.arguments|length > 0 %} + {{- tool_call.arguments|tojson(indent=4)|indent(8) }} + {{- '\n' }} + {%- else %} + {{- '{}\n' }} + {%- endif %} + {{- '}'|indent(4, first=true) }} + {%- if not loop.last %} + {{- ',\n' }} + {%- endif %} + {%- endfor %} + {{- "\n]```\n" }} + {%- elif message.role == 'assistant' %} + {{- '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content|trim + '<|END_OF_TURN_TOKEN|>' }} + {%- elif message.role == 'tool' %} + {{- '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\n' }} + {{- message.content|trim }} + {{- '<|END_OF_TURN_TOKEN|>' }} + {%- endif %} +{%- endfor %} +{{-'<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write \'Action:\' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user\'s last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example: +```json +[ + { + "tool_name": title of the tool in the specification, + "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters + } +]```<|END_OF_TURN_TOKEN|>'}} +{%- if add_generation_prompt %} + {{- '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }} +{%- endif %} diff --git a/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja b/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja new file mode 100644 index 000000000..149250bd5 --- /dev/null +++ b/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja @@ -0,0 +1,152 @@ +{%- macro json_to_python_type(json_spec) %} +{%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + +{%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]"}} +{%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} + {%- else %} + {{- "dict" }} + {%- endif %} +{%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} +{%- else %} + {{- "Any" }} +{%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- '<|im_start|>system +' }} +{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + " + +" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args: +" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- " + Returns: + " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- " +" }} + {%- endif %} +{%- endfor %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|> +' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + ' +' + message.content + '<|im_end|>' + ' +' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- ' + +' }} {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"' }} + {{- ', '}} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments|tojson }} + {%- endif %} + {%- endif %} + {{- '}' }} + {{- ' +' }} + {%- endfor %} + {{- '<|im_end|> +' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool +' }} + {%- endif %} + {{- ' +' }} + {{- message.content }} + {%- if not loop.last %} + {{- ' + +' }} + {%- else %} + {{- ' +' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant +' }} +{%- endif %} diff --git a/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja b/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja new file mode 100644 index 000000000..149250bd5 --- /dev/null +++ b/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja @@ -0,0 +1,152 @@ +{%- macro json_to_python_type(json_spec) %} +{%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + +{%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]"}} +{%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} + {%- else %} + {{- "dict" }} + {%- endif %} +{%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} +{%- else %} + {{- "Any" }} +{%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- '<|im_start|>system +' }} +{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + " + +" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args: +" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- " + Returns: + " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- " +" }} + {%- endif %} +{%- endfor %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|> +' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + ' +' + message.content + '<|im_end|>' + ' +' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- ' + +' }} {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"' }} + {{- ', '}} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments|tojson }} + {%- endif %} + {%- endif %} + {{- '}' }} + {{- ' +' }} + {%- endfor %} + {{- '<|im_end|> +' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool +' }} + {%- endif %} + {{- ' +' }} + {{- message.content }} + {%- if not loop.last %} + {{- ' + +' }} + {%- else %} + {{- ' +' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant +' }} +{%- endif %} diff --git a/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja b/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja new file mode 100644 index 000000000..bdf7919a9 --- /dev/null +++ b/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja @@ -0,0 +1,54 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }} + {%- endif %} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja new file mode 100644 index 000000000..02a1c3bce --- /dev/null +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja @@ -0,0 +1 @@ +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %} \ No newline at end of file diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja new file mode 100644 index 000000000..2ebfe7c1e --- /dev/null +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja @@ -0,0 +1,56 @@ +{% if not add_generation_prompt is defined %} +{% set add_generation_prompt = false %} +{% endif %} +{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %} +{%- for message in messages %} +{%- if message['role'] == 'system' %} +{% set ns.system_prompt = message['content'] %} +{%- endif %} +{%- endfor %} +{{bos_token}} +{{ns.system_prompt}} +{%- for message in messages %} +{%- if message['role'] == 'user' %} +{%- set ns.is_tool = false -%} +{{'<|User|>' + message['content']}} +{%- endif %} +{%- if message['role'] == 'assistant' and message['content'] is none %} +{%- set ns.is_tool = false -%} +{%- for tool in message['tool_calls']%} +{%- if not ns.is_first %} +{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} +{%- set ns.is_first = true -%} +{%- else %} +{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} +{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} +{%- endif %} +{%- endfor %} +{%- endif %} +{%- if message['role'] == 'assistant' and message['content'] is not none %} +{%- if ns.is_tool %} +{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}} +{%- set ns.is_tool = false -%} +{%- else %} +{% set content = message['content'] %} +{% if '' in content %} +{% set content = content.split('')[-1] %} +{% endif %} +{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}} +{%- endif %} +{%- endif %} +{%- if message['role'] == 'tool' %} +{%- set ns.is_tool = true -%} +{%- if ns.is_output_first %} +{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} +{%- set ns.is_output_first = false %} +{%- else %} +{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} +{%- endif %} +{%- endif %} +{%- endfor -%} +{% if ns.is_tool %} +{{'<|tool▁outputs▁end|>'}} +{% endif %} +{% if add_generation_prompt and not ns.is_tool %} +{{'<|Assistant|>'}} +{% endif %} \ No newline at end of file diff --git a/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja b/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja new file mode 100644 index 000000000..9b8136df7 --- /dev/null +++ b/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja @@ -0,0 +1,57 @@ +{%- set loop_messages = messages -%} +{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%} +{%- set system_prompt_suffix -%} +{%- filter trim -%} +In addition to plain text responses, you can chose to call one or more of the provided functions. + +Use the following rule to decide when to call a function: + * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so + * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls + +If you decide to call functions: + * prefix function calls with functools marker (no closing marker required) + * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] + * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples + * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 + * make sure you pick the right functions that match the user intent + +Available functions as JSON spec: +{%- endfilter -%} +{%- endset -%} +{%- set system_prompt_suffix = system_prompt_suffix + "\n" + functions -%} +{%- set system_prompt_suffix = system_prompt_suffix + '\nToday is ' + datetime + '.' -%} +{%- set ns = namespace(role='', content='') -%} +{#- Basic consistency checks -#} +{%- if not loop_messages -%} + {{ raise_exception('Expected non-empty messages') }} +{%- endif -%} +{%- for message in loop_messages -%} + {%- set ns.role = message['role'] | lower -%} + {%- if ns.role not in message_roles -%} + {%- set message_roles_string = message_roles | join(', ') -%} + {{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles_string + ' are supported.') }} + {%- endif -%} + {%- set msg_content = message['content'] | default('', true) | trim -%} + {%- if loop.index0 == 0 -%} + {%- if ns.role == 'system' -%} + {%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\n' + message['content'] | trim + '\n' + system_prompt_suffix + '<|eot_id|>' -%} + {%- else -%} + {%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\nYou are a helpful assistant with access to functions.\n' + system_prompt_suffix + '<|eot_id|>' -%} + {%- endif -%} + {%- set ns.content = bos_token + system_prompt -%} + {{- ns.content -}} + {%- endif -%} + {%- if loop.index0 > 0 or ns.role != 'system' -%} + {%- set ns.content = '<|start_header_id|>' + ns.role + '<|end_header_id|>\n\n' + msg_content -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- set tool = namespace(calls=[]) -%} + {%- for call in message['tool_calls'] -%} + {%- set tool.calls = tool.calls + ['{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments'] + '}'] -%} + {%- endfor -%} + {%- set ns.content = ns.content + ' functools[' + tool.calls | join(', ') + ']' -%} + {%- endif -%} + {%- set ns.content = ns.content + '<|eot_id|>' -%} + {{- ns.content -}} + {%- endif -%} +{%- endfor -%} +{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} diff --git a/models/templates/google-gemma-2-2b-it.jinja b/models/templates/google-gemma-2-2b-it.jinja new file mode 100644 index 000000000..923ec253c --- /dev/null +++ b/models/templates/google-gemma-2-2b-it.jinja @@ -0,0 +1,4 @@ +{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + ' +' + message['content'] | trim + ' +' }}{% endfor %}{% if add_generation_prompt %}{{'model +'}}{% endif %} \ No newline at end of file diff --git a/models/templates/meetkai-functionary-medium-v3.1.jinja b/models/templates/meetkai-functionary-medium-v3.1.jinja new file mode 100644 index 000000000..29d64a215 --- /dev/null +++ b/models/templates/meetkai-functionary-medium-v3.1.jinja @@ -0,0 +1,58 @@ +{# version=v3-llama3.1 #}{%- if not tools is defined -%} + {%- set tools = none -%} +{%- endif -%} + +{%- set has_code_interpreter = tools | selectattr("type", "equalto", "code_interpreter") | list | length > 0 -%} +{%- if has_code_interpreter -%} + {%- set tools = tools | rejectattr("type", "equalto", "code_interpreter") | list -%} +{%- endif -%} + +{#- System message + builtin tools #} +{{- bos_token + "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if has_code_interpreter %} + {{- "Environment: ipython\n\n" }} +{%- else -%} + {{ "\n"}} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n\n" }} +{%- if tools %} + {{- "\nYou have access to the following functions:\n\n" }} + {%- for t in tools %} + {%- if "type" in t -%} + {{ "Use the function '"|safe + t["function"]["name"] + "' to '"|safe + t["function"]["description"] + "'\n"|safe + t["function"] | tojson() }} + {%- else -%} + {{ "Use the function '"|safe + t["name"] + "' to '"|safe + t["description"] + "'\n"|safe + t | tojson() }} + {%- endif -%} + {{- "\n\n" }} + {%- endfor %} + {{- '\nThink very carefully before calling functions.\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- If looking for real time information use relevant functions before falling back to brave_search\n- Function calls MUST follow the specified format, start with \n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line\n\n' -}} +{%- endif %} +{{- "<|eot_id|>" -}} + +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- elif message['role'] == 'tool' -%} + {{ '<|start_header_id|>ipython<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- else -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}} + {%- if message['content'] -%} + {{ message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {%- if tool_call["function"]["name"] == "python" -%} + {{ '<|python_tag|>' + tool_call['function']['arguments'] }} + {%- else -%} + {{ '' + tool_call['function']['arguments'] + '' }} + {%- endif -%} + {%- endfor -%} + {{ '<|eom_id|>' }} + {%- else -%} + {{ '<|eot_id|>' }} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif -%} \ No newline at end of file diff --git a/models/templates/meetkai-functionary-medium-v3.2.jinja b/models/templates/meetkai-functionary-medium-v3.2.jinja new file mode 100644 index 000000000..74fd1e7af --- /dev/null +++ b/models/templates/meetkai-functionary-medium-v3.2.jinja @@ -0,0 +1,287 @@ +{# version=v3.llama3 #}{%- macro append_new_param_info(param_declaration, comment_info, examples_info, depth) -%} + {%- set offset = "" -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- if examples_info | length > 0 -%} + {# Append each example info #} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- endif -%} + {{ "\n" + offset + param_declaration }} +{%- endmacro -%} + +{%- macro convert_data_type(param_type) -%} + {%- if param_type == "integer" or param_type == "float" -%} + {{ "number" }} + {%- else -%} + {{ param_type }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_param_type(param) -%} + {%- set param_type = "any" -%} + + {%- if "type" in param -%} + {%- set raw_param_type = param["type"] -%} + {%- if raw_param_type is iterable and raw_param_type is not string -%} + {%- set param_type = raw_param_type | join(" | ") -%} + {%- else -%} + {%- set param_type = raw_param_type -%} + {%- endif -%} + {{ convert_data_type(param_type) }} + {%- elif "oneOf" in param -%} + {%- set one_of_types = param["oneOf"]|selectattr("type", "defined")|list -%} + {%- set one_of_types = one_of_types|map(attribute="type")|unique|list -%} + {{ convert_data_type(one_of_types | join(" | ")) }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_format_param(param) -%} + {%- if "format" in param -%} + {{ param["format"] }} + {%- elif "oneOf" in param -%} + {%- set formats = [] -%} + {%- for item in param["oneOf"] -%} + {%- if "format" in item -%} + {%- if item["format"] == param["oneOf"][-1]["format"] -%} + {{ item["format"] }} + {%- else -%} + {{ item["format"] + " or "}} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ "<|NONE|>" }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_param_info(param) -%} + {%- set param_type = param.get("type", "any") -%} + {%- set format_param = get_format_param(param) -%} + + {%- if "description" in param or "default" in param or format_param != "<|NONE|>" or param["maximum"] or param["minimum"] or param["maxLength"] or param["minLength"] -%} + {{ "//" }} + {%- if "description" in param -%} + {%- set desc = param["description"] -%} + {%- if not desc.endswith(".") -%} + {%- set desc = desc + "." -%} + {%- endif -%} + {{ " " + desc }} + {%- endif -%} + + {%- if "default" in param -%} + {%- set default_value = param["default"] -%} + {%- if param_type == "string" -%} + {%- set default_value = '"' ~ default_value ~ '"' -%} + {%- endif -%} + {{ " Default=" ~ default_value ~ "." }} + {%- endif -%} + + {%- set format_param = get_format_param(param) -%} + {%- if format_param != "<|NONE|>" -%} + {{ " Format=" ~ format_param }} + {%- endif -%} + + {%- for field, field_name in [("maximum", "Maximum"), ("minimum", "Minimum"), ("maxLength", "Maximum length"), ("minLength", "Minimum length")] -%} + {%- if field in param -%} + {{ " " + field_name ~ "=" ~ param[field] }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ "<|NONE|>"}} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_enum_option_str(enum_options) -%} + {%- for v in enum_options -%} + {%- if v is string -%} + {{ '"' + v + '"' }} + {%- else -%} + {{ v }} + {%- endif -%} + {%- if enum_options|length > 0 and v != enum_options[-1] -%} + {{ " | " }} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} + +{%- macro get_array_typescript(param_name, param_dic, depth) -%} + {%- set offset = '' -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + {%- set items_info = param_dic.get('items', {}) -%} + + {%- if items_info|length == 0 -%} + {%- if param_name -%} + {{ "\n" + offset + param_name + ": []" }} + {%- else -%} + {{ "\n" + offset + "[]" }} + {%- endif -%} + {%- else -%} + {%- set array_type = get_param_type(items_info) -%} + {%- if array_type == 'object' -%} + {%- if param_name -%} + {{ "\n" + offset + param_name + ": {" }} + {%- else -%} + {{ "\n" + offset + "{" }} + {%- endif -%} + {{ get_parameter_typescript(items_info.get('properties', {}), items_info.get('required', []), depth + 1) -}} + {{- "\n" + offset + "}[]" }} + {%- elif array_type == 'array' -%} + {%- set item_info = get_array_typescript(None, items_info, depth + 1) -%} + {%- if not param_name -%} + {{ "\n" + item_info + "[]" }} + {%- else -%} + {{ "\n" + offset + param_name + ": " + item_info|trim + "[]" }} + {%- endif -%} + {%- else -%} + {%- if 'enum' in items_info -%} + {%- set item_type = get_enum_option_str(items_info['enum']) -%} + {%- if param_name is none -%} + {{ "(" + item_type + ")[]"}} + {%- else -%} + {{ "\n" + offset + param_name + ": (" + item_type + ")[]" }} + {%- endif -%} + {%- else -%} + {%- if param_name is none -%} + {{ "\n" + array_type + "[]" }} + {%- else -%} + {{ "\n" + offset + param_name + ": " + array_type + "[]," }} + {%- endif -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_parameter_typescript(properties, required_params, depth=0) -%} + {%- set res = "" -%} + {%- for param_name, param in properties.items() -%} + {%- if param is mapping -%} + {%- set comment_info = get_param_info(param) -%} + {# Param Examples #} + {%- set examples_info = [] -%} + {%- if "examples" in param -%} + {%- set examples_info = ["Example " + param_name + ":"] -%} + {%- set examples_info = examples_info + param["examples"] -%} + {%- endif -%} + + {# Param Name declaration #} + {%- set param_declaration = param_name -%} + {%- if required_params is iterable and param_name not in required_params -%} + {%- set param_declaration = param_declaration + "?" -%} + {%- endif -%} + + {%- set param_type = get_param_type(param) -%} + + {# Handle indentation based on depth #} + {%- set offset = "" -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + + {%- if param_type == "object" -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- endif -%} + {%- if examples_info|length > 0 -%} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- set param_declaration = param_declaration + ": {" -%} + {{ "\n" + offset + param_declaration -}} + {{- get_parameter_typescript(param.get("properties", {}), param.get("required", []), depth + 1) -}} + {{- "\n" + offset + "}," }} + {%- elif param_type == "array" -%} + {%- set item_info = param.get("items", {}) -%} + {%- if "type" not in item_info -%} + {%- set param_declaration = param_declaration + ": []," -%} + {{ append_new_param_info(param_declaration, comment_info, examples_info, depth) }} + {%- else -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- endif -%} + {%- if examples_info|length > 0 -%} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- set array_declaration = get_array_typescript(param_declaration, param, depth) -%} + {%- if not array_declaration.endswith(",") -%} + {%- set array_declaration = array_declaration + "," -%} + {%- endif -%} + {{ array_declaration}} + {%- endif -%} + {%- else -%} + {%- if "enum" in param -%} + {%- set param_type = get_enum_option_str(param["enum"]) -%} + {%- endif -%} + {%- if "nullable" in param and param["nullable"] -%} + {%- set param_type = param_type + " | null" -%} + {%- endif -%} + {%- set param_declaration = param_declaration + ": " + param_type + "," -%} + {{ append_new_param_info(param_declaration, comment_info, examples_info, depth) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} + +{%- macro generate_schema_from_functions(functions, namespace='functions') -%} + {{ "// Supported function definitions that should be called when necessary.\n" -}} + {{- "namespace " + namespace + " {\n\n" -}} + + {%- for function in functions -%} + {%- if function.get("function") -%} + {%- set function = function.get("function") -%} + {%- endif -%} + + {%- set function_name = function.get("name") -%} + {%- if function_name -%} + {%- set description = function.get('description', '') -%} + {%- set parameters = function.get('parameters', {}) -%} + {{- "// " + description + "\n" -}} + {{- "type " + function_name -}} + {%- if parameters and parameters.get("properties") -%} + {{- " = (_: {" -}} + {%- set required_params = parameters.get("required", []) -%} + {{ get_parameter_typescript(parameters.get("properties"), required_params, 0) -}} + {{- "\n}) => any;\n\n" }} + {%- else -%} + {{ " = () => any;\n\n" }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {{ "} // namespace " + namespace }} +{%- endmacro -%} +{%- if not tools -%} + {%- set tools = [] -%} +{%- endif -%} +{{ bos_token + '<|start_header_id|>system<|end_header_id|>\n\nYou are capable of executing available function(s) if required.\nOnly execute function(s) when absolutely necessary.\nAsk for the required input to:recipient==all\nUse JSON for function arguments.\nRespond in this format:\n>>>${recipient}\n${content}\nAvailable functions:\n' + generate_schema_from_functions(tools) + '<|eot_id|>' -}} +{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%} + {{ '<|start_header_id|>system<|end_header_id|>\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.<|eot_id|>' }} +{%- endif -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- elif message['role'] == 'tool' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- else -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}} + {%- if message['content'] -%} + {{ '>>>all\n' + message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }} + {%- endfor -%} + {%- endif -%} + {{ '<|eot_id|>' }} + {%- endif -%} +{%- endfor -%} +{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n>>>' }}{% endif %} \ No newline at end of file diff --git a/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja b/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja new file mode 100644 index 000000000..33089ace1 --- /dev/null +++ b/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja @@ -0,0 +1,109 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- set date_string = "26 Jul 2024" %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "" %} +{%- endif %} + +{#- System message + builtin tools #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if builtin_tools is defined or tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{%- if builtin_tools is defined %} + {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} +{%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {%- if builtin_tools is defined and tool_call.name in builtin_tools %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- "<|python_tag|>" + tool_call.name + ".call(" }} + {%- for arg_name, arg_val in tool_call.arguments | items %} + {{- arg_name + '="' + arg_val + '"' }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- else %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {%- endif %} + {%- if builtin_tools is defined %} + {#- This means we're in ipython mode #} + {{- "<|eom_id|>" }} + {%- else %} + {{- "<|eot_id|>" }} + {%- endif %} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja b/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja new file mode 100644 index 000000000..1bad6a0f6 --- /dev/null +++ b/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja @@ -0,0 +1,93 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "" %} +{%- endif %} + +{#- System message #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} +{%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {{- "<|eot_id|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja b/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja new file mode 100644 index 000000000..33089ace1 --- /dev/null +++ b/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja @@ -0,0 +1,109 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- set date_string = "26 Jul 2024" %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "" %} +{%- endif %} + +{#- System message + builtin tools #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if builtin_tools is defined or tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{%- if builtin_tools is defined %} + {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} +{%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {%- if builtin_tools is defined and tool_call.name in builtin_tools %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- "<|python_tag|>" + tool_call.name + ".call(" }} + {%- for arg_name, arg_val in tool_call.arguments | items %} + {{- arg_name + '="' + arg_val + '"' }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- else %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {%- endif %} + {%- if builtin_tools is defined %} + {#- This means we're in ipython mode #} + {{- "<|eom_id|>" }} + {%- else %} + {{- "<|eot_id|>" }} + {%- endif %} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/models/templates/microsoft-Phi-3.5-mini-instruct.jinja b/models/templates/microsoft-Phi-3.5-mini-instruct.jinja new file mode 100644 index 000000000..d1533d152 --- /dev/null +++ b/models/templates/microsoft-Phi-3.5-mini-instruct.jinja @@ -0,0 +1,8 @@ +{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'user' %}{{'<|user|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|> +' + message['content'] + '<|end|> +'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> +' }}{% else %}{{ eos_token }}{% endif %} \ No newline at end of file diff --git a/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja b/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja new file mode 100644 index 000000000..9c21a3f13 --- /dev/null +++ b/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja @@ -0,0 +1,87 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{#- This block checks for alternating user/assistant messages, skipping tool calling messages #} +{%- set ns = namespace() %} +{%- set ns.index = 0 %} +{%- for message in loop_messages %} + {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %} + {%- if (message["role"] == "user") != (ns.index % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} + {%- set ns.index = ns.index + 1 %} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS][" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST]" + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST]" + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif (message.tool_calls is defined and message.tool_calls is not none) %} + {{- "[TOOL_CALLS][" }} + {%- for tool_call in message.tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length != 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }} + {%- endif %} + {{- ', "id": "' + tool_call.id + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- message["content"] + eos_token}} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS]{"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/scripts/fetch_server_test_models.py b/scripts/fetch_server_test_models.py new file mode 100755 index 000000000..05690b138 --- /dev/null +++ b/scripts/fetch_server_test_models.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +''' + This script fetches all the models used in the server tests. + + This is useful for slow tests that use larger models, to avoid them timing out on the model downloads. + + It is meant to be run from the root of the repository. + + Example: + python scripts/fetch_server_test_models.py + ( cd examples/server/tests && ./tests.sh -v -x -m slow ) +''' +import ast +import glob +import logging +import os +from typing import Generator +from pydantic import BaseModel +from typing import Optional +import subprocess + + +class HuggingFaceModel(BaseModel): + hf_repo: str + hf_file: Optional[str] = None + + class Config: + frozen = True + + +def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]: + try: + with open(test_file) as f: + tree = ast.parse(f.read()) + except Exception as e: + logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}') + return + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + for dec in node.decorator_list: + if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize': + param_names = ast.literal_eval(dec.args[0]).split(",") + if "hf_repo" not in param_names: + continue + + raw_param_values = dec.args[1] + if not isinstance(raw_param_values, ast.List): + logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}') + continue + + hf_repo_idx = param_names.index("hf_repo") + hf_file_idx = param_names.index("hf_file") if "hf_file" in param_names else None + + for t in raw_param_values.elts: + if not isinstance(t, ast.Tuple): + logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}') + continue + yield HuggingFaceModel( + hf_repo=ast.literal_eval(t.elts[hf_repo_idx]), + hf_file=ast.literal_eval(t.elts[hf_file_idx]) if hf_file_idx is not None else None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') + + models = sorted(list(set([ + model + for test_file in glob.glob('examples/server/tests/unit/test_*.py') + for model in collect_hf_model_test_parameters(test_file) + ])), key=lambda m: (m.hf_repo, m.hf_file)) + + logging.info(f'Found {len(models)} models in parameterized tests:') + for m in models: + logging.info(f' - {m.hf_repo} / {m.hf_file}') + + cli_path = os.environ.get( + 'LLAMA_SERVER_BIN_PATH', + os.path.join( + os.path.dirname(__file__), + '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli')) + + for m in models: + if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file): + continue + if m.hf_file is not None and '-of-' in m.hf_file: + logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file') + continue + logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched') + cmd = [ + cli_path, + '-hfr', m.hf_repo, + *([] if m.hf_file is None else ['-hff', m.hf_file]), + '-n', '1', + '-p', 'Hey', + '--no-warmup', + '--log-disable', + '-no-cnv'] + if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo: + cmd.append('-fa') + try: + subprocess.check_call(cmd) + except subprocess.CalledProcessError: + logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}') + exit(1) diff --git a/scripts/get_hf_chat_template.py b/scripts/get_chat_template.py old mode 100755 new mode 100644 similarity index 86% rename from scripts/get_hf_chat_template.py rename to scripts/get_chat_template.py index 23bb1de59..e8982d11a --- a/scripts/get_hf_chat_template.py +++ b/scripts/get_chat_template.py @@ -4,12 +4,12 @@ If a model has multiple chat templates, you can specify the variant name. Syntax: - ./scripts/get_hf_chat_template.py model_id [variant] + ./scripts/get_chat_template.py model_id [variant] Examples: - ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct - ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use - ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct + ./scripts/get_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct + ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use + ./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct ''' import json @@ -17,7 +17,7 @@ import re import sys -def get_hf_chat_template(model_id, variant=None): +def get_chat_template(model_id, variant=None): try: # Use huggingface_hub library if available. # Allows access to gated models if the user has access and ran `huggingface-cli login`. @@ -69,7 +69,7 @@ def main(args): model_id = args[0] variant = None if len(args) < 2 else args[1] - template = get_hf_chat_template(model_id, variant) + template = get_chat_template(model_id, variant) sys.stdout.write(template) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index bebe4e9a3..6be5cbe0e 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) { } } } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); rules.clear(); return false; } @@ -960,10 +960,28 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy =*/ false, + /* .awaiting_trigger = */ false, + /* .trigger_buffer = */ "", + /* .trigger_tokens = */ {}, + /* .trigger_words = */ {}, + }; } -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { llama_grammar_parser parser; // if there is a grammar, parse it @@ -1035,10 +1053,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, } } while (true); + std::vector vec_trigger_tokens; + std::vector vec_trigger_words; + for (size_t i = 0; i < num_trigger_tokens; i++) { + GGML_ASSERT(trigger_tokens != nullptr); + vec_trigger_tokens.push_back(trigger_tokens[i]); + } + for (size_t i = 0; i < num_trigger_words; i++) { + GGML_ASSERT(trigger_words != nullptr); + vec_trigger_words.push_back(trigger_words[i]); + } + // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, + /* .trigger_buffer = */ "", + std::move(vec_trigger_tokens), + std::move(vec_trigger_words), + }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1055,6 +1094,11 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.rules, grammar.stacks, grammar.partial_utf8, + grammar.lazy, + grammar.awaiting_trigger, + grammar.trigger_buffer, + grammar.trigger_tokens, + grammar.trigger_words, }; // redirect elements in stacks to point to new rules @@ -1076,6 +1120,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { GGML_ASSERT(grammar.vocab != nullptr); + if (grammar.awaiting_trigger) { + return; + } + bool allow_eog = false; for (const auto & stack : grammar.stacks) { if (stack.empty()) { @@ -1115,6 +1163,34 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { GGML_ASSERT(grammar.vocab != nullptr); + const auto & piece = grammar.vocab->token_to_piece(token); + + if (grammar.awaiting_trigger) { + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { + grammar.awaiting_trigger = false; + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, piece); + LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); + return; + } else { + // TODO: consider a smarter incremental substring search algorithm (store last position to search from). + grammar.trigger_buffer += piece; + for (const auto & word : grammar.trigger_words) { + auto pos = grammar.trigger_buffer.find(word); + if (pos != std::string::npos) { + grammar.awaiting_trigger = false; + auto constrained_str = grammar.trigger_buffer.substr(pos); + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, constrained_str); + LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str()); + return; + } + } + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str()); + return; + } + } + if (grammar.vocab->is_eog(token)) { for (const auto & stack : grammar.stacks) { if (stack.empty()) { @@ -1124,8 +1200,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token GGML_ABORT("fatal error"); } - const std::string & piece = grammar.vocab->token_to_piece(token); + llama_grammar_accept_str(grammar, piece); +} +void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; diff --git a/src/llama-grammar.h b/src/llama-grammar.h index f8b40c651..252d54d4c 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -114,6 +114,15 @@ struct llama_grammar { // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; + + // lazy grammars wait for trigger words or tokens before constraining the sampling. + // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // (useful e.g. for tool_choice=required) + bool lazy = false; + bool awaiting_trigger = false; // Initialized to true for lazy grammars only + std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). + std::vector trigger_words; }; // @@ -127,7 +136,15 @@ struct llama_grammar * llama_grammar_init_impl( size_t n_rules, size_t start_rule_index); -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); void llama_grammar_free_impl(struct llama_grammar * grammar); @@ -141,3 +158,7 @@ void llama_grammar_apply_impl( void llama_grammar_accept_impl( struct llama_grammar & grammar, llama_token token); + +void llama_grammar_accept_str( + struct llama_grammar & grammar, + const std::string & piece); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index b3a12386e..26974f539 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1433,13 +1433,30 @@ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token } } +// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle. +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_grammar *) smpl->ctx; if (!ctx->grammar) { return; } - auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); + std::vector trigger_words; + for (auto & word : ctx->grammar->trigger_words) { + trigger_words.push_back(word.c_str()); + } + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), + ctx->grammar->lazy, trigger_words.data(), trigger_words.size(), + ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); llama_grammar_free_impl(ctx->grammar); ctx->grammar = grammar_new; @@ -1448,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; - auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr); + auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0); // copy the state { @@ -1484,7 +1501,15 @@ static struct llama_sampler_i llama_sampler_grammar_i = { /* .free = */ llama_sampler_grammar_free, }; -struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { auto * ctx = new llama_sampler_grammar; if (grammar_str != nullptr && grammar_str[0] != '\0') { @@ -1492,7 +1517,7 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * voc /* .vocab = */ vocab, /* .grammar_str = */ grammar_str, /* .grammar_root = */ grammar_root, - /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root), + /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens), }; } else { *ctx = { @@ -1509,6 +1534,24 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * voc }; } +struct llama_sampler * llama_sampler_init_grammar( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0); +} + +struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens); +} + // penalties struct llama_sampler_penalties { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3fa43c295..40f83ff0d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -93,6 +93,7 @@ if (NOT WIN32) llama_target_and_test(test-grammar-parser.cpp) llama_target_and_test(test-grammar-integration.cpp) llama_target_and_test(test-llama-grammar.cpp) + llama_target_and_test(test-chat.cpp) # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8 if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 190643136..4563f9dcb 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -328,7 +328,7 @@ int main(void) { // test llama_chat_format_single for system message printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); std::vector chat2; - common_chat_msg sys_msg{"system", "You are a helpful assistant"}; + common_chat_msg sys_msg{"system", "You are a helpful assistant", {}}; auto fmt_sys = [&](std::string tmpl_str) { minja::chat_template tmpl(tmpl_str, "", ""); @@ -352,10 +352,10 @@ int main(void) { // test llama_chat_format_single for user message printf("\n\n=== llama_chat_format_single (user message) ===\n\n"); - chat2.push_back({"system", "You are a helpful assistant"}); - chat2.push_back({"user", "Hello"}); - chat2.push_back({"assistant", "I am assistant"}); - common_chat_msg new_msg{"user", "How are you"}; + chat2.push_back({"system", "You are a helpful assistant", {}}); + chat2.push_back({"user", "Hello", {}}); + chat2.push_back({"assistant", "I am assistant", {}}); + common_chat_msg new_msg{"user", "How are you", {}}; auto fmt_single = [&](std::string tmpl_str) { minja::chat_template tmpl(tmpl_str, "", ""); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp new file mode 100644 index 000000000..ccc65d87a --- /dev/null +++ b/tests/test-chat.cpp @@ -0,0 +1,521 @@ +// Tests chat handling, including grammar generation and parsing for tool calling, for various templates. +// +// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates, +// e.g. given Minja (http://github.com/google/minja) checked out in parent dir: +// +// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null +// +#include +#include +#include +#include + +#include "chat-template.hpp" +#include "chat.hpp" +#include "llama-grammar.h" +#include "unicode.h" + +using json = nlohmann::ordered_json; + +static common_chat_msg msg_from_json(const json & message) { + common_chat_msg ret{ + "assistant", + "", + {}, + }; + if (message.contains("content") && !message.at("content").is_null()) { + ret.content = message.at("content").get(); + } + auto has_tool_calls = message.contains("tool_calls"); + if (has_tool_calls) { + for (const auto & tc : message.at("tool_calls")) { + const auto & arguments = tc.at("function").at("arguments"); + ret.tool_calls.push_back({ + tc.at("function").at("name").get(), + arguments.is_string() ? arguments.get() : arguments.dump(), + tc.contains("id") ? tc.at("id").get() : "", + }); + } + } + return ret; +} + +template static void assert_equals(const T & expected, const T & actual) { + if (expected != actual) { + std::cerr << "Expected: " << expected << std::endl; + std::cerr << "Actual: " << actual << std::endl; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } +} + +static std::string read_file(const std::string & path) { + std::cerr << "# Reading: " << path << std::endl << std::flush; + std::ifstream fs(path, std::ios_base::binary); + if (!fs.is_open()) { + fs = std::ifstream("../" + path, std::ios_base::binary); + if (!fs.is_open()) { + throw std::runtime_error("Failed to open file: " + path); + } + } + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + std::string out; + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); + return out; +} + +static std::unique_ptr build_grammar(const std::string & grammar_str) { + return std::unique_ptr( + llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0)); +} + +// TODO: extract to common helper (copied from test-grammar-integration.cpp) +static bool match_string(const std::string & input, llama_grammar * grammar) { + const auto cpts = unicode_cpts_from_utf8(input); + + auto & stacks_cur = llama_grammar_get_stacks(grammar); + + for (const auto & cpt : cpts) { + llama_grammar_accept(grammar, cpt); + + if (stacks_cur.empty()) { + // no stacks means that the grammar failed to match at this point + return false; + } + } + + for (const auto & stack : stacks_cur) { + if (stack.empty()) { + // An empty stack means that the grammar has been completed + return true; + } + } + + return false; +} + +// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`. +static std::string dump(const json & j) { + return minja::Value(j).dump(-1, /* to_json= */ true); +} + +static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { + assert_equals(expected.role, actual.role); + assert_equals(expected.content, actual.content); + assert_equals(expected.tool_calls.size(), actual.tool_calls.size()); + for (size_t i = 0; i < expected.tool_calls.size(); i++) { + const auto & expected_tool_call = expected.tool_calls[i]; + const auto & actual_tool_call = actual.tool_calls[i]; + assert_equals(expected_tool_call.name, actual_tool_call.name); + assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments))); + assert_equals(expected_tool_call.id, actual_tool_call.id); + } +} + +const auto special_function_tool = json::parse(R"({ + "type": "function", + "function": { + "name": "special_function", + "description": "I'm special", + "parameters": { + "type": "object", + "properties": { + "arg1": { + "type": "integer", + "description": "The arg." + } + }, + "required": ["arg1"] + } + } +})"); +const auto python_tool = json::parse(R"({ + "type": "function", + "function": { + "name": "python", + "description": "an ipython interpreter", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute." + } + }, + "required": ["code"] + } + } +})"); +const auto code_interpreter_tool = json::parse(R"({ + "type": "function", + "function": { + "name": "code_interpreter", + "description": "an ipython interpreter", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute." + } + }, + "required": ["code"] + } + } +})"); +const json tools = { special_function_tool, python_tool }; +const json llama_3_1_tools = { special_function_tool, code_interpreter_tool }; + +struct delta_data { + std::string delta; + std::string grammar; + common_chat_format format; +}; + +static delta_data init_delta(const common_chat_template & tmpl, const std::vector & end_tokens, + const json & user_message, const json & delta_message, const json & tools, + const json & tool_choice) { + common_chat_inputs inputs; + inputs.parallel_tool_calls = true; + inputs.messages = json::array(); + inputs.messages.push_back(user_message); + inputs.tools = tools; + inputs.tool_choice = tool_choice; + auto params_prefix = common_chat_params_init(tmpl, inputs); + + inputs.messages.push_back(delta_message); + inputs.add_generation_prompt = false; + auto params_full = common_chat_params_init(tmpl, inputs); + + std::string prefix = params_prefix.prompt; + std::string full = params_full.prompt; + + // Check full starts with prefix + if (full.find(prefix) != 0) { + fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str()); + throw std::runtime_error("Full message does not start with prefix"); + } + + if (full == prefix) { + throw std::runtime_error("Full message is the same as the prefix"); + } + + auto delta = full.substr(prefix.size()); + + // Strip end tokens + for (const auto & end_token : end_tokens) { + // rfind to find the last occurrence + auto pos = delta.rfind(end_token); + if (pos != std::string::npos) { + delta = delta.substr(0, pos); + break; + } + } + return { delta, params_full.grammar, params_full.format }; +} + +/* + Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false, + gets the diff, removes any end tokens and parses the result w/ the grammar, checking that + the parsed message is the same as the test_message +*/ +static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, + const json & test_message, const json & tools = {}, const std::string & expected_delta = "", + bool skip_grammar_test = false, bool skip_parser_test = false) { + common_chat_msg expected_msg = msg_from_json(test_message); + + auto user_message = json{ + { "role", "user" }, + { "content", "Hello, world!" } + }; + + for (const auto & tool_choice : json({ "auto", "required" })) { + auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice); + if (!expected_delta.empty()) { + assert_equals(expected_delta, data.delta); + } + + if (!skip_parser_test) { + const auto msg = common_chat_parse(data.delta, data.format); + assert_msg_equals(expected_msg, msg); + } + + if (!expected_msg.tool_calls.empty()) { + GGML_ASSERT(!data.grammar.empty()); + } + if (!data.grammar.empty()) { + auto grammar = build_grammar(data.grammar); + if (!grammar) { + throw std::runtime_error("Failed to build grammar"); + } + // TODO: exercice lazy grammars + triggers here, instead of skipping the test + if (!skip_grammar_test) { + if (!match_string(data.delta, grammar.get())) { + throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta + + "\n\nGrammar: " + data.grammar); + } + } + } + } +} + +static void test_template_output_parsers() { + auto text_message = json{ + { "role", "assistant" }, + { "content", "Hello, world!" }, + }; + auto tool_call_message = json{ + { "role", "assistant" }, + { "content", {} }, + { "tool_calls", json{ { + { "type", "function" }, + { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, + } } } + }; + auto tool_call_message_with_id = json::parse(tool_call_message.dump()); + tool_call_message_with_id["tool_calls"][0]["id"] = "123456789"; + + auto python_tool_call_message = json{ + { "role", "assistant" }, + { "content", {} }, + { "tool_calls", json{ { + { "type", "function" }, + { "function", + { + { "name", "python" }, + { "arguments", + { + { "code", "print('hey')" }, + } }, + } }, + } } } + }; + auto code_interpreter_tool_call_message = json{ + { "role", "assistant" }, + { "content", {} }, + { "tool_calls", json{ { + { "type", "function" }, + { "function", + { + { "name", "code_interpreter" }, + { "arguments", + { + { "code", "print('hey')" }, + } }, + } }, + } } } + }; + + common_chat_inputs inputs_no_tools; + inputs_no_tools.messages = { + { { "role", "user" }, { "content", "Hey" } } + }; + + common_chat_inputs inputs_tools = inputs_no_tools; + inputs_tools.tools = json::array(); + inputs_tools.tools.push_back(special_function_tool); + + common_chat_inputs inputs_tools_builtin = inputs_no_tools; + inputs_tools_builtin.tools = json::array(); + inputs_tools_builtin.tools.push_back(python_tool); + + { + const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "", ""); + std::vector end_tokens{ "" }; + + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_GENERIC, + common_chat_params_init( + common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), + "", ""), + inputs_tools) + .format); + + // Generic tool calls doesn't generate / parse content-only messages symmetrically. + + assert_msg_equals(msg_from_json(text_message), + common_chat_parse("{\n" + " \"response\": \"Hello, world!\"\n" + "}", + common_chat_params_init(tmpl, inputs_tools).format)); + test_template(tmpl, end_tokens, tool_call_message_with_id, tools, + "{\n" + " \"tool_calls\": [\n" + " {\n" + " \"name\": \"special_function\",\n" + " \"arguments\": {\n" + " \"arg1\": 1\n" + " },\n" + " \"id\": \"123456789\"\n" + " }\n" + " ]\n" + "}"); + } + { + const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", + ""); + std::vector end_tokens{ "" }; + + assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format); + + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template( + tmpl, end_tokens, tool_call_message_with_id, tools, + "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]", + /* skip_grammar_test= */ true); + } + { + const common_chat_template tmpl( + read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); + std::vector end_tokens{ "<|im_end|>" }; + + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals( + COMMON_CHAT_FORMAT_HERMES_2_PRO, + common_chat_params_init( + common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), + "", ""), + inputs_tools) + .format); + assert_equals( + COMMON_CHAT_FORMAT_HERMES_2_PRO, + common_chat_params_init( + common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""), + inputs_tools) + .format); + + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + ""); + test_template(tmpl, end_tokens, python_tool_call_message, tools, + "\n" + "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n" + ""); + } + { + const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "", + ""); + std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, + common_chat_params_init(tmpl, inputs_tools_builtin).format); + assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, + common_chat_params_init( + common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), + "", ""), + inputs_tools_builtin) + .format); + + // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools, + "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); + test_template(tmpl, end_tokens, python_tool_call_message, tools, + "<|python_tag|>python.call(code=\"print('hey')\")"); + test_template(tmpl, end_tokens, tool_call_message, tools, + "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); + } + { + const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", + ""); + std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); + + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); + } + { + const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "", + ""); + std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, + common_chat_params_init(tmpl, inputs_tools).format); + + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "{\"arg1\": 1}"); + } + { + const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "", + ""); + std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; + + assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format); + + test_template(tmpl, end_tokens, text_message, {}, + "all\n" + "Hello, world!", + /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "special_function\n" + "{\"arg1\": 1}"); + } + { + const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "", + ""); + std::vector end_tokens{ "<|eot_id|>" }; + + assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format); + + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]"); + } + { + const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), + "", ""); + std::vector end_tokens{ "<|end▁of▁sentence|>" }; + + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); + + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, tool_call_message, tools, + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n" + "{\"arg1\": 1}\n" + "```<|tool▁call▁end|>"); + } +} + +int main(int argc, char ** argv) { +#ifndef _WIN32 + if (argc > 1) { + common_chat_inputs inputs; + inputs.messages = { + { { "role", "user" }, { "content", "Hey" } } + }; + inputs.tools = json::array({ special_function_tool }); + + std::cout << "| Template | Format |\n"; + std::cout << "|----------|--------|\n"; + + for (int i = 1; i < argc; i++) { + std::string path = argv[i]; + if (path.rfind(".jinja") != path.size() - 6) { + std::cerr << "Skipping non-jinja file: " << path << std::endl; + continue; + } + common_chat_template tmpl(read_file(path), "", ""); + auto parts = string_split(path, "/"); + auto name = parts[parts.size() - 1]; + std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format) + << " |\n"; + } + } else +#endif + { + test_template_output_parsers(); + std::cout << "\n[chat] All tests passed!" << std::endl; + } + return 0; +} diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index e1bdbb925..288e08f51 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -13,7 +13,7 @@ using json = nlohmann::ordered_json; static llama_grammar * build_grammar(const std::string & grammar_str) { - return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); + return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0); } static bool test_build_grammar_fails(const std::string & grammar_str) {