Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-02-05 08:00:42 +01:00.
Add --jinja to llama-run
commit 8dd4f334a4
parent 18f257bf1a

@@ -1830,8 +1830,10 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
     std::string default_template_src = chat_template_override;
     std::string tool_use_template_src = chat_template_override;
     if (chat_template_override.empty()) {
-        default_template_src = llama_model_chat_template(model, /* name */ nullptr);
-        tool_use_template_src = llama_model_chat_template(model, /* name */ "tool_use");
+        auto str = llama_model_chat_template(model, /* name */ nullptr);
+        if (str) default_template_src = str;
+        str = llama_model_chat_template(model, /* name */ "tool_use");
+        if (str) tool_use_template_src = str;
     }
     if (default_template_src.empty() || default_template_src == "chatml") {
         if (!tool_use_template_src.empty()) {
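
llama_model_chat_template() returns a C string that is null when the model's GGUF stores no template under the requested name; the old code assigned that return value straight into a std::string, which is undefined behavior for a null pointer. A minimal sketch of the guarded lookup this hunk introduces (illustrative only, reusing the variables of the surrounding function):

    // Illustrative sketch, not part of the patch; mirrors the hunk above.
    std::string default_template_src = chat_template_override;          // caller-supplied override, may be empty
    const char * str = llama_model_chat_template(model, /* name */ nullptr);
    if (str) {
        default_template_src = str;                                      // only adopt the model's template when one exists
    }

The same pattern is then repeated for the "tool_use" template lookup.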
@@ -103,6 +103,7 @@ class Opt {
     llama_model_params model_params;
     std::string model_;
     std::string user;
+    bool use_jinja = false;
     int context_size = -1, ngl = -1;
     float temperature = -1;
     bool verbose = false;
@@ -154,6 +155,8 @@ class Opt {
         } else if (options_parsing &&
                    (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
             verbose = true;
+        } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
+            use_jinja = true;
         } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
             help = true;
             return 0;
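
For context, a plausible invocation exercising the new flag would be along the lines of `llama-run --jinja <model>`; the model specifier here is a placeholder rather than something taken from this commit, and the flag has to appear while options_parsing is still active, i.e. among the leading option arguments.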
@@ -711,13 +714,31 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }

 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(LlamaData & llama_data, const bool append) {
+static int apply_chat_template(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
+    if (use_jinja) {
+        json messages = json::array();
+        for (const auto & msg : llama_data.messages) {
+            messages.push_back({
+                {"role", msg.role},
+                {"content", msg.content}
+            });
+        }
+        try {
+            auto result = tmpl.apply(messages, /* tools= */ json(), append);
+            llama_data.fmtted.resize(result.size() + 1);
+            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+            return llama_data.fmtted.size();
+        } catch (const std::exception & e) {
+            printe("failed to render the chat template: %s\n", e.what());
+            return -1;
+        }
+    }
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(), llama_data.messages.size(), append,
+        tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(),
+        result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
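
A condensed sketch of what the new Jinja branch amounts to, outside the patch. Assumptions not taken from this commit: minja::chat_template can be constructed from a template-source string plus BOS/EOS token strings, json is nlohmann::json as elsewhere in llama.cpp's common code, and the final bool passed to apply() corresponds to adding a generation prompt; the tokens and message content below are made up.

    // Sketch only: build a role/content message array and render it through minja.
    minja::chat_template tmpl(default_template_src, /* bos_token */ "<s>", /* eos_token */ "</s>");
    json messages = json::array({
        { { "role", "user" }, { "content", "Hello" } }
    });
    std::string rendered = tmpl.apply(messages, /* tools= */ json(), /* add_generation_prompt= */ true);

On the non-Jinja path the same tmpl object only contributes its raw source string, which is handed to llama_chat_apply_template() as before.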
@@ -847,8 +868,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
 }

 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) {
-    const int new_len = apply_chat_template(llama_data, append);
+static int apply_chat_template_with_error_handling(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+    const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");
         return -1;
@@ -911,9 +932,10 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 }

 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user) {
+static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
+    auto chat_templates = llama_chat_templates_from_model(llama_data.model.get(), "");
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -924,7 +946,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {

         add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
-        if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.default_template, llama_data, true, new_len, use_jinja) < 0) {
             return 1;
         }

@@ -939,7 +961,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
         }

         add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.default_template, llama_data, false, prev_len, use_jinja) < 0) {
             return 1;
         }
     }
@@ -999,7 +1021,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }

-    if (chat_loop(llama_data, opt.user)) {
+    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
         return 1;
     }
