Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-02-04 23:52:32 +01:00)

Refactor common_chat_* functions to accept minja template + use_jinja option

parent 3ed670b6dd
commit b75d0622e4
@@ -74,6 +74,15 @@
 #endif
 #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
 
+const char * LLAMA_CHATML_TEMPLATE = R"(
+{%- for message in messages -%}
+{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+{{- "<|im_start|>assistant\n" -}}
+{%- endif -%}
+)";
+
 //
 // CURL utils
 //
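
To make the new constant concrete, here is a minimal sketch of what the chatml fallback template renders to via minja. The standalone program, the inlined copy of the template string, and the single-message conversation are illustrative only, not part of this change; it assumes chat-template.hpp and a nlohmann/json header are on the include path, as they are under common/.

// Sketch only: render the chatml fallback template with minja and print the result.
#include "chat-template.hpp"     // minja::chat_template (vendored under common/)
#include <nlohmann/json.hpp>     // or the json.hpp vendored in the tree
#include <cstdio>
#include <string>

int main() {
    using json = nlohmann::ordered_json; // minja works on ordered_json

    // local copy of LLAMA_CHATML_TEMPLATE to keep the sketch self-contained
    const std::string chatml_src = R"(
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
)";

    minja::chat_template tmpl(chatml_src, /* bos_token= */ "", /* eos_token= */ "");

    json messages = json::array();
    messages.push_back({{"role", "user"}, {"content", "Hello"}});

    // Expected output, roughly:
    //   <|im_start|>user
    //   Hello<|im_end|>
    //   <|im_start|>assistant
    printf("%s\n", tmpl.apply(messages, /* tools= */ json(), /* add_generation_prompt= */ true).c_str());
    return 0;
}
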
@@ -1748,56 +1757,56 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     return res >= 0;
 }
 
-std::string common_chat_apply_template(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_apply_template(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & msgs,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja) {
+    if (use_jinja) {
+        auto messages = json::array();
+        for (const auto & msg : msgs) {
+            messages.push_back({{"role", msg.role}, {"content", msg.content}});
+        }
+        return tmpl.apply(messages, /* tools= */ json(), add_ass);
+    }
+
     int alloc_size = 0;
-    bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
     for (const auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model, /* name */ nullptr) : tmpl.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
-        if (ptr_tmpl != nullptr) {
-            // if the custom "tmpl" is not supported, we throw an error
-            // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
-            throw std::runtime_error("this custom template is not supported");
-        }
-
-        // If the built-in template is not supported, we default to chatml
-        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-        fallback = true;
+        // if the custom "tmpl" is not supported, we throw an error
+        // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+        throw std::runtime_error("this custom template is not supported");
     }
 
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(
-            fallback ? "chatml" : ptr_tmpl,
-            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);
     return formatted_chat;
 }
 
-std::string common_chat_format_single(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_format_single(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
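
For orientation, a hedged sketch of how a call site changes with the new signature: instead of passing the model plus a template string, the caller first resolves a llama_chat_templates bundle from the model and then hands the template object plus use_jinja to the wrapper. The helper name build_prompt and the surrounding boilerplate are illustrative, not part of this commit.

// Sketch only: exercise the refactored common_chat_apply_template().
#include "common.h"   // common_chat_msg, llama_chat_templates_from_model, common_chat_apply_template
#include <string>
#include <vector>

static std::string build_prompt(const struct llama_model * model, bool use_jinja) {
    // resolve the template(s) embedded in the model metadata (no --chat-template override here)
    llama_chat_templates templates = llama_chat_templates_from_model(model, /* chat_template_override= */ "");

    std::vector<common_chat_msg> msgs = {
        {"system", "You are a helpful assistant"},
        {"user",   "Hello"},
    };

    // use_jinja == true  -> rendered by minja via tmpl.apply(...)
    // use_jinja == false -> rendered by llama_chat_apply_template(tmpl.source().c_str(), ...)
    return common_chat_apply_template(templates.default_template, msgs, /* add_ass= */ true, use_jinja);
}
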
@@ -1805,29 +1814,20 @@ std::string common_chat_format_single(const struct llama_model * model,
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
 }
 
-std::string common_chat_format_example(const struct llama_model * model, const minja::chat_template & tmpl, bool use_jinja) {
+std::string common_chat_format_example(const llama_chat_template & tmpl, bool use_jinja) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "How are you?"},
     };
-    const auto add_generation_prompt = true;
-    if (use_jinja) {
-        auto messages = json::array();
-        for (const auto & msg : msgs) {
-            messages.push_back({{"role", msg.role}, {"content", msg.content}});
-        }
-        return tmpl.apply(messages, /* tools= */ json(), add_generation_prompt);
-    } else {
-        return common_chat_apply_template(model, tmpl.source(), msgs, add_generation_prompt);
-    }
+    return common_chat_apply_template(tmpl, msgs, true, use_jinja);
 }
 
 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
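
As before the refactor, common_chat_format_single computes an incremental piece rather than the full prompt: it formats the history once without the new message and once with it, then returns only the trailing difference so the caller can tokenize just the delta. A one-line sketch of that suffix computation (the helper name is illustrative):

// Sketch only: the "diff part" returned by common_chat_format_single.
#include <string>

static std::string incremental_part(const std::string & fmt_past, const std::string & fmt_new) {
    // fmt_past is assumed to be a prefix of fmt_new, so the newly added text is the remaining suffix
    return fmt_new.substr(fmt_past.size(), fmt_new.size() - fmt_past.size());
}
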
@@ -1847,14 +1847,7 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
         if (!tool_use_template_src.empty()) {
             default_template_src = tool_use_template_src;
         } else {
-            default_template_src = R"(
-{%- for message in messages -%}
-{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
-{%- endfor -%}
-{%- if add_generation_prompt -%}
-{{- "<|im_start|>assistant\n" -}}
-{%- endif -%}
-)";
+            default_template_src = LLAMA_CHATML_TEMPLATE;
         }
     }
     return {
@@ -26,6 +26,8 @@
 
 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
 
+extern const char * LLAMA_CHATML_TEMPLATE;
+
 struct common_adapter_lora_info {
     std::string path;
     float scale;
@@ -602,29 +604,32 @@ struct common_chat_msg {
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
 
+typedef minja::chat_template llama_chat_template;
 
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
-std::string common_chat_apply_template(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_apply_template(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & chat,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja);
 
 // Format single message, while taking into account the position of that message in chat history
-std::string common_chat_format_single(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_format_single(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja);
 
 // Returns an example of formatted chat
-std::string common_chat_format_example(const struct llama_model * model,
-        const minja::chat_template & tmpl, bool use_jinja);
+std::string common_chat_format_example(
+        const llama_chat_template & tmpl, bool use_jinja);
 
 struct llama_chat_templates {
-    minja::chat_template default_template;
-    std::optional<minja::chat_template> tool_use_template;
+    llama_chat_template default_template;
+    std::optional<llama_chat_template> tool_use_template;
 };
 
 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
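
A short sketch of how the new header-level types compose, mirroring the selection logic used by the server change further down (the have_tools flag and the helper name are illustrative, not part of this commit):

// Sketch only: prefer the tool-use template when the request involves tools.
#include "common.h"   // llama_chat_template, llama_chat_templates

static const llama_chat_template & pick_template(const llama_chat_templates & templates, bool have_tools) {
    return have_tools && templates.tool_use_template
        ? *templates.tool_use_template
        : templates.default_template;
}
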
@@ -84,14 +84,6 @@ static void sigint_handler(int signo) {
 }
 #endif
 
-static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
-    common_chat_msg new_msg{role, content};
-    auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
-    chat_msgs.push_back({role, content});
-    LOG_DBG("formatted: '%s'\n", formatted.c_str());
-    return formatted;
-}
-
 int main(int argc, char ** argv) {
     common_params params;
     g_params = &params;
@@ -226,7 +218,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, chat_templates.default_template, params.use_jinja).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.default_template, params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@@ -270,10 +262,18 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
+    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
+        common_chat_msg new_msg{role, content};
+        auto formatted = common_chat_format_single(chat_templates.default_template, chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        chat_msgs.push_back({role, content});
+        LOG_DBG("formatted: '%s'\n", formatted.c_str());
+        return formatted;
+    };
+
     {
         auto prompt = (params.conversation_mode && params.enable_chat_template)
             // format the system prompt in conversation mode (fallback to default if empty)
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
+            ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
             // otherwise use the prompt as is
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@@ -780,7 +780,7 @@ int main(int argc, char ** argv) {
             }
 
             if (params.enable_chat_template) {
-                chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                chat_add_and_format("assistant", assistant_ss.str());
             }
             is_interacting = true;
             LOG("\n");
@@ -845,7 +845,7 @@ int main(int argc, char ** argv) {
 
             bool format_chat = params.conversation_mode && params.enable_chat_template;
             std::string user_inp = format_chat
-                ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                ? chat_add_and_format("user", std::move(buffer))
                 : std::move(buffer);
             // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
             const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
@@ -714,7 +714,7 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }
 
 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
+static int apply_chat_template(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
     if (use_jinja) {
         json messages = json::array();
         for (const auto & msg : llama_data.messages) {
@@ -868,7 +868,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
 }
 
 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+static int apply_chat_template_with_error_handling(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
     const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");
@@ -3869,7 +3869,7 @@ int main(int argc, char ** argv) {
         auto body = json::parse(req.body);
         const auto & templates = get_chat_templates();
         const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template;
-        json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja);
+        json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
 
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
@@ -4288,7 +4288,7 @@ int main(int argc, char ** argv) {
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
         get_chat_templates().default_template.source().c_str(),
-        common_chat_format_example(ctx_server.model, get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str());
+        common_chat_format_example(get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
@@ -351,7 +351,7 @@ static llama_tokens format_infill(
 }
 
 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const llama_chat_template & tmpl, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
@@ -379,7 +379,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         chat.push_back({role, content});
     }
 
-    const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
+    const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
 
     return formatted_chat;
@@ -579,9 +579,8 @@ static json oaicompat_completion_params_parse(const json & body) {
 }
 
 static json oaicompat_completion_params_parse(
-    const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const minja::chat_template & tmpl,
+    const llama_chat_template & tmpl,
     bool use_jinja)
 {
     json llama_params;
@@ -622,7 +621,7 @@ static json oaicompat_completion_params_parse(
     if (use_jinja) {
         llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
     } else {
-        llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
+        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
     }
 
     // Handle "n" field
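
In other words, /v1/chat/completions only goes through minja when --jinja is enabled; otherwise the OpenAI-style messages are flattened into common_chat_msg and rendered by the legacy C API. A simplified sketch of that non-jinja path follows (the real format_chat above also validates roles and may handle non-string content, which is omitted here; the helper name legacy_prompt is illustrative):

// Sketch only: simplified version of the legacy (non-jinja) server prompt path.
#include "common.h"
#include <nlohmann/json.hpp>
#include <string>
#include <vector>

static std::string legacy_prompt(const llama_chat_template & tmpl, const nlohmann::json & messages) {
    std::vector<common_chat_msg> chat;
    for (const auto & m : messages) {
        // assumes every message carries string "role" and "content" fields
        chat.push_back({ m.at("role").get<std::string>(), m.at("content").get<std::string>() });
    }
    // always the non-jinja path here; the jinja path calls tmpl.apply(...) directly
    return common_chat_apply_template(tmpl, chat, /* add_ass= */ true, /* use_jinja= */ false);
}
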
@@ -8,6 +8,7 @@
 #include "llama.h"
 #include "common.h"
 #include "chat-template.hpp"
+#include "llama-chat.h"
 
 int main(void) {
     std::vector<llama_chat_message> conversation {
@@ -319,9 +320,10 @@ int main(void) {
     std::vector<common_chat_msg> chat2;
     common_chat_msg sys_msg{"system", "You are a helpful assistant"};
 
-    auto fmt_sys = [&](std::string tmpl) {
-        auto output = common_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
-        printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
+    auto fmt_sys = [&](std::string tmpl_str) {
+        minja::chat_template tmpl(tmpl_str, "", "");
+        auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
+        printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
     };
@@ -345,9 +347,10 @@ int main(void) {
     chat2.push_back({"assistant", "I am assistant"});
     common_chat_msg new_msg{"user", "How are you"};
 
-    auto fmt_single = [&](std::string tmpl) {
-        auto output = common_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
-        printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
+    auto fmt_single = [&](std::string tmpl_str) {
+        minja::chat_template tmpl(tmpl_str, "", "");
+        auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
+        printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
     };
@@ -362,5 +365,7 @@ int main(void) {
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
     assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
 
+    assert(llm_chat_detect_template(LLAMA_CHATML_TEMPLATE) == LLM_CHAT_TEMPLATE_CHATML);
+
     return 0;
 }