From 72c177c1f6c16693eee319d4ebd4eaab5e630dd2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 11 May 2024 17:28:10 +0200 Subject: [PATCH] fix system prompt handling (#7153) --- examples/server/server.cpp | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 55c1d4129..ceaeb1f76 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -651,9 +651,6 @@ struct server_context { std::string system_prompt; std::vector system_tokens; - std::string name_user; // this should be the antiprompt - std::string name_assistant; - // slots / clients std::vector slots; json default_generation_settings_for_props; @@ -1100,15 +1097,11 @@ struct server_context { system_need_update = false; } - void system_prompt_set(const json & sys_props) { - system_prompt = sys_props.value("prompt", ""); - name_user = sys_props.value("anti_prompt", ""); - name_assistant = sys_props.value("assistant_name", ""); + bool system_prompt_set(const std::string & sys_prompt) { + system_prompt = sys_prompt; LOG_VERBOSE("system prompt process", { {"system_prompt", system_prompt}, - {"name_user", name_user}, - {"name_assistant", name_assistant}, }); // release all slots @@ -1117,6 +1110,7 @@ struct server_context { } system_need_update = true; + return true; } bool process_token(completion_token_output & result, server_slot & slot) { @@ -1536,7 +1530,8 @@ struct server_context { } if (task.data.contains("system_prompt")) { - system_prompt_set(task.data.at("system_prompt")); + std::string sys_prompt = json_value(task.data, "system_prompt", std::string()); + system_prompt_set(sys_prompt); for (server_slot & slot : slots) { slot.n_past = 0; @@ -2920,7 +2915,7 @@ int main(int argc, char ** argv) { server_params_parse(argc, argv, sparams, params); if (!sparams.system_prompt.empty()) { - ctx_server.system_prompt_set(json::parse(sparams.system_prompt)); + ctx_server.system_prompt_set(sparams.system_prompt); } if (params.model_alias == "unknown") { @@ -3409,8 +3404,7 @@ int main(int argc, char ** argv) { const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { - { "user_name", ctx_server.name_user.c_str() }, - { "assistant_name", ctx_server.name_assistant.c_str() }, + { "system_prompt", ctx_server.system_prompt.c_str() }, { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params.n_parallel } };