Add --jinja to llama-run

This commit is contained in:
ochafik 2025-01-13 22:07:49 +00:00
parent 18f257bf1a
commit 8dd4f334a4
2 changed files with 35 additions and 11 deletions

View File

@ -1830,8 +1830,10 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
std::string default_template_src = chat_template_override; std::string default_template_src = chat_template_override;
std::string tool_use_template_src = chat_template_override; std::string tool_use_template_src = chat_template_override;
if (chat_template_override.empty()) { if (chat_template_override.empty()) {
default_template_src = llama_model_chat_template(model, /* name */ nullptr); auto str = llama_model_chat_template(model, /* name */ nullptr);
tool_use_template_src = llama_model_chat_template(model, /* name */ "tool_use"); if (str) default_template_src = str;
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) tool_use_template_src = str;
} }
if (default_template_src.empty() || default_template_src == "chatml") { if (default_template_src.empty() || default_template_src == "chatml") {
if (!tool_use_template_src.empty()) { if (!tool_use_template_src.empty()) {

View File

@ -103,6 +103,7 @@ class Opt {
llama_model_params model_params; llama_model_params model_params;
std::string model_; std::string model_;
std::string user; std::string user;
bool use_jinja = false;
int context_size = -1, ngl = -1; int context_size = -1, ngl = -1;
float temperature = -1; float temperature = -1;
bool verbose = false; bool verbose = false;
@ -154,6 +155,8 @@ class Opt {
} else if (options_parsing && } else if (options_parsing &&
(parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) { (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
verbose = true; verbose = true;
} else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
use_jinja = true;
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) { } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
help = true; help = true;
return 0; return 0;
@ -711,13 +714,31 @@ static void add_message(const char * role, const std::string & text, LlamaData &
} }
// Function to apply the chat template and resize `formatted` if needed // Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(LlamaData & llama_data, const bool append) { static int apply_chat_template(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
if (use_jinja) {
json messages = json::array();
for (const auto & msg : llama_data.messages) {
messages.push_back({
{"role", msg.role},
{ "content", msg.content}
});
}
try {
auto result = tmpl.apply(messages, /* tools= */ json(), append);
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return llama_data.fmtted.size();
} catch (const std::exception & e) {
printe("failed to render the chat template: %s\n", e.what());
return -1;
}
}
int result = llama_chat_apply_template( int result = llama_chat_apply_template(
llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(), llama_data.messages.size(), append, tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) { if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result); llama_data.fmtted.resize(result);
result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(), result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(), llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size()); llama_data.fmtted.size());
} }
@ -847,8 +868,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
} }
// Helper function to apply the chat template and handle errors // Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) { static int apply_chat_template_with_error_handling(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(llama_data, append); const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
if (new_len < 0) { if (new_len < 0) {
printe("failed to apply the chat template\n"); printe("failed to apply the chat template\n");
return -1; return -1;
@ -911,9 +932,10 @@ static int get_user_input(std::string & user_input, const std::string & user) {
} }
// Main chat loop function // Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user) { static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
int prev_len = 0; int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
auto chat_templates = llama_chat_templates_from_model(llama_data.model.get(), "");
static const bool stdout_a_terminal = is_stdout_a_terminal(); static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) { while (true) {
// Get user input // Get user input
@ -924,7 +946,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
add_message("user", user.empty() ? user_input : user, llama_data); add_message("user", user.empty() ? user_input : user, llama_data);
int new_len; int new_len;
if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) { if (apply_chat_template_with_error_handling(chat_templates.default_template, llama_data, true, new_len, use_jinja) < 0) {
return 1; return 1;
} }
@ -939,7 +961,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
} }
add_message("assistant", response, llama_data); add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) { if (apply_chat_template_with_error_handling(chat_templates.default_template, llama_data, false, prev_len, use_jinja) < 0) {
return 1; return 1;
} }
} }
@ -999,7 +1021,7 @@ int main(int argc, const char ** argv) {
return 1; return 1;
} }
if (chat_loop(llama_data, opt.user)) { if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
return 1; return 1;
} }