diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index 2eca8a9fb..ff4ad6994 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -15,13 +15,11 @@ using json = nlohmann::json;
 
 inline static json oaicompat_completion_params_parse(
+    const struct llama_model * model,
     const json &body, /* openai api json semantics */
     const std::string &chat_template)
 {
     json llama_params;
 
-    std::string formatted_prompt = chat_template == "chatml"
-        ? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...)
-        : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)
 
     llama_params["__oaicompat"] = true;
 
@@ -34,7 +32,7 @@ inline static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"] = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"] = formatted_prompt;
+    llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"] = json_value(body, "temperature", 0.0);
     llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 23482ed95..c7821eca6 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -37,7 +37,7 @@ struct server_params
     std::string hostname = "127.0.0.1";
     std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
-    std::string chat_template = "chatml";
+    std::string chat_template = "";
     int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
@@ -1937,8 +1937,9 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
    printf("                            types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
    printf("  -gan N, --grp-attn-n N    set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
    printf("  -gaw N, --grp-attn-w N    set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
-    printf("  --chat-template FORMAT_NAME");
-    printf("                            set chat template, possible value is: llama2, chatml (default %s)", sparams.chat_template.c_str());
+    printf("  --chat-template JINJA_TEMPLATE\n");
+    printf("                            set custom jinja chat template (default: template taken from model's metadata)\n");
+    printf("                            Note: only commonly used templates are accepted, since we don't have jinja parser\n");
    printf("\n");
 }
 
@@ -2389,13 +2390,13 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 invalid_param = true;
                 break;
             }
-            std::string value(argv[i]);
-            if (value != "chatml" && value != "llama2") {
-                fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
+            if (!verify_custom_template(argv[i])) {
+                fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
+                fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
                 invalid_param = true;
                 break;
             }
-            sparams.chat_template = value;
+            sparams.chat_template = argv[i];
         }
         else if (arg == "--override-kv")
         {
@@ -2913,7 +2914,7 @@ int main(int argc, char **argv)
             if (!validate_api_key(req, res)) {
                 return;
             }
-            json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);
+            json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
 
             const int task_id = llama.queue_tasks.get_new_id();
             llama.queue_results.add_waiting_task_id(task_id);
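
Note: with the change above, oai.hpp no longer formats the OpenAI-style `messages` array itself; the handler passes it as-is to `format_chat` (added in the utils.hpp hunk below). A minimal sketch of the request body this path consumes, for illustration only: the model name, temperature, and message contents are made up, only the field names and the role/content extraction mirror the patched code.

```cpp
// Sketch of the OpenAI-style body the /v1/chat/completions path consumes (illustrative only).
#include <cstdio>
#include <string>
#include "json.hpp" // nlohmann::json, single header used by the server example

using json = nlohmann::json;

int main() {
    json body = {
        {"model", "gpt-3.5-turbo"},   // forwarded as llama_params["model"]
        {"temperature", 0.7},
        {"messages", json::array({
            {{"role", "system"}, {"content", "You are a helpful assistant"}},
            {{"role", "user"},   {"content", "Hello"}}
        })}
    };

    // format_chat(model, chat_template, body["messages"]) walks the array like this:
    for (const auto & msg : body["messages"]) {
        const std::string role    = msg.value("role", std::string(""));
        const std::string content = msg.value("content", std::string(""));
        printf("%s: %s\n", role.c_str(), content.c_str());
    }
    return 0;
}
```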
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 0ee670dba..e954fb0ef 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -167,50 +167,47 @@ static T json_value(const json &body, const std::string &key, const T &default_v
         : default_value;
 }
 
-inline std::string format_llama2(std::vector<json> messages)
-{
-    std::ostringstream output;
-    bool is_inside_turn = false;
-
-    for (auto it = messages.begin(); it != messages.end(); ++it) {
-        if (!is_inside_turn) {
-            output << "[INST] ";
-        }
-        std::string role = json_value(*it, "role", std::string("user"));
-        std::string content = json_value(*it, "content", std::string(""));
-        if (role == "system") {
-            output << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
-            is_inside_turn = true;
-        } else if (role == "user") {
-            output << content << " [/INST]";
-            is_inside_turn = true;
-        } else {
-            output << " " << content << " ";
-            is_inside_turn = false;
-        }
-    }
-
-    LOG_VERBOSE("format_llama2", {{"text", output.str()}});
-
-    return output.str();
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+inline bool verify_custom_template(const std::string & tmpl) {
+    llama_chat_message chat[] = {{"user", "test"}};
+    std::vector<char> buf(1);
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
+    return res >= 0;
 }
 
-inline std::string format_chatml(std::vector<json> messages)
+// Format given chat. If tmpl is empty, we take the template from model metadata
+inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages)
 {
-    std::ostringstream chatml_msgs;
+    size_t alloc_size = 0;
+    // vector holding all allocated string to be passed to llama_chat_apply_template
+    std::vector<std::string> str(messages.size() * 2);
+    std::vector<llama_chat_message> chat(messages.size());
 
-    for (auto it = messages.begin(); it != messages.end(); ++it) {
-        chatml_msgs << "<|im_start|>"
-                    << json_value(*it, "role", std::string("user")) << '\n';
-        chatml_msgs << json_value(*it, "content", std::string(""))
-                    << "<|im_end|>\n";
+    for (size_t i = 0; i < messages.size(); ++i) {
+        auto &curr_msg = messages[i];
+        str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
+        str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
+        alloc_size += str[i*2 + 1].length();
+        chat[i].role = str[i*2 + 0].c_str();
+        chat[i].content = str[i*2 + 1].c_str();
     }
 
-    chatml_msgs << "<|im_start|>assistant" << '\n';
+    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    std::vector<char> buf(alloc_size * 2);
 
-    LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
+    // run the first time to get the total output length
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
 
-    return chatml_msgs.str();
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
+    }
+
+    std::string formatted_chat(buf.data(), res);
+    LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
+
+    return formatted_chat;
 }
 
 //
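
For reference, here is a standalone sketch of the measure-then-retry pattern that `format_chat` relies on. It is not part of the patch: the chatml-style template literal, the messages, and the initial buffer size are arbitrary; passing a null model works because the template string is supplied explicitly, exactly as `verify_custom_template` does above.

```cpp
// Minimal sketch (not part of the patch): apply an explicit chat template without loading a model.
#include <cstdio>
#include <string>
#include <vector>
#include "llama.h"

int main() {
    // example template only; llama.cpp matches it by the <|im_start|> marker rather than parsing the jinja
    const char * tmpl =
        "{% for message in messages %}"
        "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}"
        "{% endfor %}"
        "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}";

    llama_chat_message chat[] = {
        { "system", "You are a helpful assistant" },
        { "user",   "Hello" },
    };
    const size_t n_msg = sizeof(chat) / sizeof(chat[0]);

    std::vector<char> buf(64); // deliberately small to exercise the resize path
    int32_t res = llama_chat_apply_template(nullptr, tmpl, chat, n_msg, true, buf.data(), buf.size());
    if (res < 0) {
        fprintf(stderr, "template not supported\n");
        return 1;
    }
    if ((size_t) res > buf.size()) {
        // the first call only reported the required length
        buf.resize(res);
        res = llama_chat_apply_template(nullptr, tmpl, chat, n_msg, true, buf.data(), buf.size());
    }
    printf("%s", std::string(buf.data(), res).c_str());
    return 0;
}
```

A negative return value is how an unrecognized template is detected; that is the condition the new `--chat-template` check in `server_params_parse` reports as unsupported.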
diff --git a/llama.cpp b/llama.cpp
index 5de07dfa9..4296eca32 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -12602,7 +12602,7 @@ LLAMA_API int32_t llama_chat_apply_template(
         // load template from model
         std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
         std::string template_key = "tokenizer.chat_template";
-        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size());
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
         if (res < 0) {
             // worst case: there is no information about template, we will use chatml by default
             curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
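
The llama.cpp hunk fixes the size argument passed to `llama_model_meta_val_str`: `curr_tmpl` is still empty at that point, so `curr_tmpl.size()` is 0 and the stored `tokenizer.chat_template` value was never copied out of the metadata. A rough sketch of the intended lookup, for illustration only: the model path is a placeholder, and the no-argument `llama_backend_init()` is assumed.

```cpp
// Illustrative only: read tokenizer.chat_template from a GGUF's metadata.
// "model.gguf" is a placeholder path.
#include <cstdio>
#include <vector>
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.vocab_only = true; // metadata and vocab are enough, skip the weights

    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // same key the patched code reads; 2048 mirrors the buffer size used above
    std::vector<char> buf(2048, 0);
    int32_t res = llama_model_meta_val_str(model, "tokenizer.chat_template", buf.data(), buf.size());
    if (res < 0) {
        printf("no chat template in metadata; llama.cpp falls back to chatml\n");
    } else {
        printf("chat template (%d chars):\n%s\n", res, buf.data());
    }

    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```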