From 96952e7181929c6001b2bc69a33f240de731cc3a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 24 Jul 2024 13:48:46 +0200 Subject: [PATCH] llama : fix `llama_chat_format_single` for mistral (#8657) * fix `llama_chat_format_single` for mistral * fix typo * use printf --- common/common.cpp | 2 +- examples/main/main.cpp | 1 + tests/test-chat-template.cpp | 30 ++++++++++++++++++++++++------ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 4c19132f1..ec44a0552 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2723,7 +2723,7 @@ std::string llama_chat_format_single(const struct llama_model * model, const llama_chat_msg & new_msg, bool add_ass) { std::ostringstream ss; - auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false); + auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false); std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index a0d817b1a..61e960ea2 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -124,6 +124,7 @@ static std::string chat_add_and_format(struct llama_model * model, std::vectorchat_template, chat_msgs, new_msg, role == "user"); chat_msgs.push_back({role, content}); + LOG("formatted: %s\n", formatted.c_str()); return formatted; } diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 6583dd0b2..46a7d3aea 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -133,13 +132,31 @@ int main(void) { ); formatted_chat.resize(res); std::string output(formatted_chat.data(), formatted_chat.size()); - std::cout << output << "\n-------------------------\n"; + printf("%s\n", output.c_str()); + printf("-------------------------\n"); assert(output == expected); } - // test llama_chat_format_single - std::cout << "\n\n=== llama_chat_format_single ===\n\n"; + + // test llama_chat_format_single for system message + printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); std::vector chat2; + llama_chat_msg sys_msg{"system", "You are a helpful assistant"}; + + auto fmt_sys = [&](std::string tmpl) { + auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false); + printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str()); + printf("-------------------------\n", output.c_str()); + return output; + }; + assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n"); + assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n"); + assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message + assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>"); + + + // test llama_chat_format_single for user message + printf("\n\n=== llama_chat_format_single (user message) ===\n\n"); chat2.push_back({"system", "You are a helpful assistant"}); chat2.push_back({"user", "Hello"}); chat2.push_back({"assistant", "I am assistant"}); @@ -147,12 +164,13 @@ int main(void) { auto fmt_single = [&](std::string tmpl) { auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true); - std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n"; + printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str()); + printf("-------------------------\n", output.c_str()); return output; }; assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n"); assert(fmt_single("llama2") == "[INST] How are you [/INST]"); - assert(fmt_single("gemma") == "\nuser\nHow are you\nmodel\n"); + assert(fmt_single("gemma") == "\nuser\nHow are you\nmodel\n"); assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); return 0;