// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
//
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
//
// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
//
#include <fstream>
#include <iostream>
#include <json.hpp>
#include <string>
#include "chat-template.hpp"
#include "chat.hpp"
#include "llama-grammar.h"
#include "unicode.h"
using json = nlohmann::ordered_json;
static common_chat_msg msg_from_json(const json & message) {
    common_chat_msg ret{
        "assistant",
        "",
        {},
    };
    if (message.contains("content") && !message.at("content").is_null()) {
        ret.content = message.at("content").get<std::string>();
    }
    auto has_tool_calls = message.contains("tool_calls");
    if (has_tool_calls) {
        for (const auto & tc : message.at("tool_calls")) {
            const auto & arguments = tc.at("function").at("arguments");
            ret.tool_calls.push_back({
                tc.at("function").at("name").get<std::string>(),
                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                tc.contains("id") ? tc.at("id").get<std::string>() : "",
            });
        }
    }
    return ret;
}
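// For illustration (hypothetical input, not executed): given
//   {"role": "assistant", "tool_calls": [{"function": {"name": "f", "arguments": {"x": 1}}}]}
// msg_from_json returns a common_chat_msg with role "assistant", empty content and a
// single tool call { name: "f", arguments: "{\"x\":1}", id: "" }; non-string
// arguments are serialized with json::dump, and a missing "id" becomes "".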
template <class T> static void assert_equals(const T & expected, const T & actual) {
    if (expected != actual) {
        std::cerr << "Expected: " << expected << std::endl;
        std::cerr << "Actual: " << actual << std::endl;
        std::cerr << std::flush;
        throw std::runtime_error("Test failed");
    }
}
static std::string read_file(const std::string & path) {
    std::cerr << "# Reading: " << path << std::endl << std::flush;
    std::ifstream fs(path, std::ios_base::binary);
    if (!fs.is_open()) {
        // Fall back to the parent directory, e.g. when running from a build subdirectory.
        fs = std::ifstream("../" + path, std::ios_base::binary);
        if (!fs.is_open()) {
            throw std::runtime_error("Failed to open file: " + path);
        }
    }
    fs.seekg(0, std::ios_base::end);
    auto size = fs.tellg();
    fs.seekg(0);
    std::string out;
    out.resize(static_cast<size_t>(size));
    fs.read(&out[0], static_cast<std::streamsize>(size));
    return out;
}
static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
    return std::unique_ptr<llama_grammar>(
        llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
}
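// For illustration (hypothetical GBNF, not used by the tests): build_grammar("root ::= \"a\"+")
// should yield a grammar accepting one or more 'a' characters; if the grammar string
// fails to parse, the returned unique_ptr holds nullptr, which callers check below.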
// TODO: extract to common helper (copied from test-grammar-integration.cpp)
static bool match_string(const std::string & input, llama_grammar * grammar) {
    const auto cpts = unicode_cpts_from_utf8(input);

    auto & stacks_cur = llama_grammar_get_stacks(grammar);

    for (const auto & cpt : cpts) {
        llama_grammar_accept(grammar, cpt);

        if (stacks_cur.empty()) {
            // no stacks means that the grammar failed to match at this point
            return false;
        }
    }

    for (const auto & stack : stacks_cur) {
        if (stack.empty()) {
            // An empty stack means that the grammar has been completed
            return true;
        }
    }

    return false;
}
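// A rough sketch of the semantics: with a grammar like `root ::= "hi"`, feeding the
// codepoints of "hi" completes a stack (it becomes empty), so match_string returns true;
// an input the grammar cannot extend, e.g. "hi!", empties stacks_cur mid-loop and
// returns false.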
// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
static std::string dump(const json & j) {
    return minja::Value(j).dump(-1, /* to_json= */ true);
}
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
    assert_equals(expected.role, actual.role);
    assert_equals(expected.content, actual.content);
    assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
    for (size_t i = 0; i < expected.tool_calls.size(); i++) {
        const auto & expected_tool_call = expected.tool_calls[i];
        const auto & actual_tool_call   = actual.tool_calls[i];
        assert_equals(expected_tool_call.name, actual_tool_call.name);
        assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
        assert_equals(expected_tool_call.id, actual_tool_call.id);
    }
}
const auto special_function_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "special_function",
    "description": "I'm special",
    "parameters": {
      "type": "object",
      "properties": {
        "arg1": {
          "type": "integer",
          "description": "The arg."
        }
      },
      "required": ["arg1"]
    }
  }
})");
const auto python_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "python",
    "description": "an ipython interpreter",
    "parameters": {
      "type": "object",
      "properties": {
        "code": {
          "type": "string",
          "description": "Python code to execute."
        }
      },
      "required": ["code"]
    }
  }
})");
const auto code_interpreter_tool = json::parse(R"({
  "type": "function",
  "function": {
    "name": "code_interpreter",
    "description": "an ipython interpreter",
    "parameters": {
      "type": "object",
      "properties": {
        "code": {
          "type": "string",
          "description": "Python code to execute."
        }
      },
      "required": ["code"]
    }
  }
})");
const json tools = { special_function_tool, python_tool };
const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };
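// The tool definitions above follow the OpenAI-style function-calling schema: a
// "function" object carrying "name", "description" and a JSON-Schema "parameters" block.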
struct delta_data {
    std::string        delta;
    std::string        grammar;
    common_chat_format format;
};
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                             const json & user_message, const json & delta_message, const json & tools,
                             const json & tool_choice) {
    common_chat_inputs inputs;
    inputs.parallel_tool_calls = true;
    inputs.messages = json::array();
    inputs.messages.push_back(user_message);
    inputs.tools = tools;
    inputs.tool_choice = tool_choice;
    auto params_prefix = common_chat_params_init(tmpl, inputs);

    inputs.messages.push_back(delta_message);
    inputs.add_generation_prompt = false;
    auto params_full = common_chat_params_init(tmpl, inputs);

    std::string prefix = params_prefix.prompt;
    std::string full = params_full.prompt;

    // Check full starts with prefix
    if (full.find(prefix) != 0) {
        fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
        throw std::runtime_error("Full message does not start with prefix");
    }

    if (full == prefix) {
        throw std::runtime_error("Full message is the same as the prefix");
    }

    auto delta = full.substr(prefix.size());

    // Strip end tokens
    for (const auto & end_token : end_tokens) {
        // rfind to find the last occurrence
        auto pos = delta.rfind(end_token);
        if (pos != std::string::npos) {
            delta = delta.substr(0, pos);
            break;
        }
    }
    return { delta, params_full.grammar, params_full.format };
}
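// For illustration (placeholder markup, not a real template): if rendering [user] with
// add_generation_prompt=true yields "<user>Hello</user><assistant>" and rendering
// [user, assistant] yields "<user>Hello</user><assistant>Hi!<eos>", then the delta is
// "Hi!<eos>" and stripping the end token leaves just "Hi!".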
/*
    Applies the template to one user message with add_generation_prompt=true, then to the
    same conversation plus the test message with add_generation_prompt=false; takes the
    diff, strips any end tokens, and parses the result with the grammar, checking that
    the parsed message matches test_message.
*/
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                          const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
                          bool skip_grammar_test = false, bool skip_parser_test = false) {
    common_chat_msg expected_msg = msg_from_json(test_message);

    auto user_message = json{
        { "role", "user" },
        { "content", "Hello, world!" }
    };

    for (const auto & tool_choice : json({ "auto", "required" })) {
        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
        if (!expected_delta.empty()) {
            assert_equals(expected_delta, data.delta);
        }

        if (!skip_parser_test) {
            const auto msg = common_chat_parse(data.delta, data.format);
            assert_msg_equals(expected_msg, msg);
        }

        if (!expected_msg.tool_calls.empty()) {
            GGML_ASSERT(!data.grammar.empty());
        }

        if (!data.grammar.empty()) {
            auto grammar = build_grammar(data.grammar);
            if (!grammar) {
                throw std::runtime_error("Failed to build grammar");
            }
            // TODO: exercise lazy grammars + triggers here, instead of skipping the test
            if (!skip_grammar_test) {
                if (!match_string(data.delta, grammar.get())) {
                    throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
                                             "\n\nGrammar: " + data.grammar);
                }
            }
        }
    }
}
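// Usage sketch (mirroring the calls below): test_template(tmpl, end_tokens, tool_call_message,
// tools, expected) renders the conversation with tool_choice "auto" and "required",
// asserts the rendered delta equals `expected`, round-trips it through common_chat_parse,
// and validates it against the generated grammar.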
static void test_template_output_parsers() {
    auto text_message = json{
        { "role", "assistant" },
        { "content", "Hello, world!" },
    };
    auto tool_call_message = json{
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
              { "type", "function" },
              { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
          } } }
    };
    auto tool_call_message_with_id = json::parse(tool_call_message.dump());
    tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";

    auto python_tool_call_message = json{
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
              { "type", "function" },
              { "function",
                {
                    { "name", "python" },
                    { "arguments",
                      {
                          { "code", "print('hey')" },
                      } },
                } },
          } } }
    };
    auto code_interpreter_tool_call_message = json{
        { "role", "assistant" },
        { "content", {} },
        { "tool_calls", json{ {
              { "type", "function" },
              { "function",
                {
                    { "name", "code_interpreter" },
                    { "arguments",
                      {
                          { "code", "print('hey')" },
                      } },
                } },
          } } }
    };

    common_chat_inputs inputs_no_tools;
    inputs_no_tools.messages = {
        { { "role", "user" }, { "content", "Hey" } }
    };

    common_chat_inputs inputs_tools = inputs_no_tools;
    inputs_tools.tools = json::array();
    inputs_tools.tools.push_back(special_function_tool);

    common_chat_inputs inputs_tools_builtin = inputs_no_tools;
    inputs_tools_builtin.tools = json::array();
    inputs_tools_builtin.tools.push_back(python_tool);
    {
        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<end_of_turn>" };

        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_GENERIC,
                      common_chat_params_init(
                          common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
                                               "<s>", "</s>"),
                          inputs_tools)
                          .format);

        // Generic tool call handling doesn't generate / parse content-only messages symmetrically.
        assert_msg_equals(msg_from_json(text_message),
                          common_chat_parse("{\n"
                                            "  \"response\": \"Hello, world!\"\n"
                                            "}",
                                            common_chat_params_init(tmpl, inputs_tools).format));
        test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
                      "{\n"
                      "  \"tool_calls\": [\n"
                      "    {\n"
                      "      \"name\": \"special_function\",\n"
                      "      \"arguments\": {\n"
                      "        \"arg1\": 1\n"
                      "      },\n"
                      "      \"id\": \"123456789\"\n"
                      "    }\n"
                      "  ]\n"
                      "}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "</s>" };

        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(
            tmpl, end_tokens, tool_call_message_with_id, tools,
            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
            /* skip_grammar_test= */ true);
    }
    {
        const common_chat_template tmpl(
            read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|im_end|>" };

        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_params_init(
                common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
                                     "<s>", "</s>"),
                inputs_tools)
                .format);
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_params_init(
                common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
                inputs_tools)
                .format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<tool_call>\n"
                      "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                      "</tool_call>");
        test_template(tmpl, end_tokens, python_tool_call_message, tools,
                      "<tool_call>\n"
                      "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
                      "</tool_call>");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_params_init(tmpl, inputs_tools_builtin).format);
        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                      common_chat_params_init(
                          common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
                                               "<s>", "</s>"),
                          inputs_tools_builtin)
                          .format);

        // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
                      "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
        test_template(tmpl, end_tokens, python_tool_call_message, tools,
                      "<|python_tag|>python.call(code=\"print('hey')\")");
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                      common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<function=special_function>{\"arg1\": 1}</function>");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, {},
                      "all\n"
                      "Hello, world!",
                      /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "special_function\n"
                      "{\"arg1\": 1}");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
                                        "</s>");
        std::vector<std::string> end_tokens{ "<|eot_id|>" };

        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
    }
    {
        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
                                        "<s>", "</s>");
        std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);

        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
        test_template(tmpl, end_tokens, tool_call_message, tools,
                      "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                      "```json\n"
                      "{\"arg1\": 1}\n"
                      "```<|tool▁call▁end|>");
    }
}
int main(int argc, char ** argv) {
#ifndef _WIN32
    if (argc > 1) {
        common_chat_inputs inputs;
        inputs.messages = {
            { { "role", "user" }, { "content", "Hey" } }
        };
        inputs.tools = json::array({ special_function_tool });

        std::cout << "| Template | Format |\n";
        std::cout << "|----------|--------|\n";

        for (int i = 1; i < argc; i++) {
            std::string path = argv[i];
            if (path.rfind(".jinja") != path.size() - 6) {
                std::cerr << "Skipping non-jinja file: " << path << std::endl;
                continue;
            }
            common_chat_template tmpl(read_file(path), "", "");
            auto parts = string_split(path, "/");
            auto name = parts[parts.size() - 1];
            std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format)
                      << " |\n";
        }
    } else
#endif
    {
        test_template_output_parsers();
        std::cout << "\n[chat] All tests passed!" << std::endl;
    }
    return 0;
}