From a0a08eedb6a23b31d8783bbb91ede583cbe7933a Mon Sep 17 00:00:00 2001 From: kir-gadjello <111190790+kir-gadjello@users.noreply.github.com> Date: Wed, 22 Nov 2023 02:16:38 -0300 Subject: [PATCH] Add openai-compatible POST /v1/chat/completions API endpoint to server example --- examples/server/server.cpp | 347 ++++++++++++++++++++++++++++++++++++- 1 file changed, 346 insertions(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1f2c55f2d..25c23d30b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -29,6 +29,8 @@ #define SERVER_VERBOSE 1 #endif +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" + using json = nlohmann::json; struct server_params @@ -63,6 +65,10 @@ static bool server_verbose = false; // base64 utils (TODO: move to common in the future) // +nlohmann::json oaicompat_completion_params_parse( + const nlohmann::json &body); +std::string format_chatml(std::vector messages); + static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghijklmnopqrstuvwxyz" @@ -377,6 +383,9 @@ struct llama_client_slot bool stopped_eos = false; bool stopped_word = false; bool stopped_limit = false; + + bool oaicompat = false; + std::string oaicompat_model = ""; std::string stopping_word; @@ -676,7 +685,16 @@ struct llama_server_context bool launch_slot_with_data(llama_client_slot* &slot, json data) { slot_params default_params; llama_sampling_params default_sparams; - + + if (data.count("__oaicompat") != 0) { + slot->oaicompat = true; + slot->oaicompat_model = + json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + } else { + slot->oaicompat = false; + slot->oaicompat_model = ""; + } + slot->params.stream = json_value(data, "stream", false); slot->params.cache_prompt = json_value(data, "cache_prompt", false); slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict); @@ -1169,6 +1187,12 @@ struct llama_server_context res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); } + if (slot.oaicompat) + { + res.result_json["oaicompat_token_ctr"] = slot.n_decoded; + res.result_json["model"] = slot.oaicompat_model; + } + queue_results.push_back(res); } @@ -1216,6 +1240,12 @@ struct llama_server_context res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs); } + if (slot.oaicompat) + { + res.result_json["oaicompat_token_ctr"] = slot.n_decoded; + res.result_json["model"] = slot.oaicompat_model; + } + queue_results.push_back(res); } @@ -2178,6 +2208,249 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } } + +static std::string random_string() { + std::string str( + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); + + std::random_device rd; + std::mt19937 generator(rd()); + + std::shuffle(str.begin(), str.end(), generator); + + return str.substr(0, 32); // assumes 32 < number of characters in str +} + +static std::string gen_chatcmplid() { + std::stringstream chatcmplid; + chatcmplid << "chatcmpl-" << random_string(); + return chatcmplid.str(); +} + +std::string format_chatml(std::vector messages) { + + std::ostringstream chatml_msgs; + + // iterate the array + for (auto it = messages.begin(); it != messages.end(); ++it) { + chatml_msgs << "<|im_start|>" + << json_value(*it, "role", std::string("user")) << '\n'; + chatml_msgs << json_value(*it, "content", std::string("")) + << "<|im_end|>\n"; + } + + chatml_msgs << "<|im_start|>assistant" << '\n'; + + return chatml_msgs.str(); +} + +/* llama.cpp completion api semantics */ +nlohmann::json oaicompat_completion_params_parse( + const nlohmann::json &body /* openai api json semantics */) { + nlohmann::json llama_params; + + llama_params["__oaicompat"] = true; + + // Map OpenAI parameters to llama.cpp parameters + llama_params["prompt"] = format_chatml( + body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' + llama_params["temperature"] = + json_value(body, "temperature", 0.8); // Default to 0.8 if not provided + llama_params["top_k"] = + json_value(body, "max_tokens", 40); // Default to 40 if not provided + llama_params["top_p"] = + json_value(body, "top_p", 0.95); // Default to 0.95 if not provided + llama_params["n_predict"] = + json_value(body, "max_tokens", -1); // Default to -1 if not provided + llama_params["logit_bias"] = json_value( + body, "logit_bias", + nlohmann::json::object()); // Default to empty object if not provided + llama_params["frequency_penalty"] = json_value( + body, "frequency_penalty", 0.0); // Default to 0.0 if not provided + llama_params["presence_penalty"] = json_value( + body, "presence_penalty", 0.0); // Default to 0.0 if not provided + llama_params["seed"] = json_value(body, "seed", 0); + llama_params["stream"] = + json_value(body, "stream", false); // Default to 0 if not provided + llama_params["mirostat"] = + json_value(body, "mirostat", false); // Default to false if not provided + llama_params["mirostat_tau"] = + json_value(body, "mirostat_tau", 0.0); // Default to 0.0 if not provided + llama_params["mirostat_eta"] = + json_value(body, "mirostat_eta", 0.0); // Default to 0.0 if not provided + llama_params["penalize_nl"] = json_value( + body, "penalize_nl", false); // Default to false if not provided + llama_params["typical_p"] = + json_value(body, "typical_p", 0.0); // Default to 0.0 if not provided + llama_params["repeat_last_n"] = + json_value(body, "repeat_last_n", 0); // Default to 0 if not provided + llama_params["ignore_eos"] = + json_value(body, "ignore_eos", false); // Default to false if not provided + llama_params["tfs_z"] = + json_value(body, "tfs_z", 0.0); // Default to 0.0 if not provided + if (llama_params.count("grammar") != 0) { + llama_params["grammar"] = json_value( + body, "grammar", + nlohmann::json::object()); // Default to empty object if not provided + } + + // Handle 'stop' field + if (body["stop"].is_null()) { + llama_params["stop"] = json::array({}); + } else if (body["stop"].is_string()) { + llama_params["stop"] = json::array({body["stop"].get()}); + } else { + llama_params["stop"] = json_value( + body, "stop", + json::array()); // Default to empty array if not provided + } + + llama_params["stop"].push_back("<|im_end|>"); + + return llama_params; +} + +static json format_final_response_oaicompat(json request, task_result response, + bool streaming = false) { + + json result = response.result_json; + + bool stopped_word = result.count("stopped_word") != 0; + bool stopped_eos = json_value(result, "stopped_eos", false); + int num_tokens_predicted = json_value(result, "tokens_predicted", 0); + int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); + std::string content = json_value(result, "content", std::string("")); + + std::string finish_reason = "length"; + if (stopped_word || stopped_eos) { + finish_reason = "stop"; + } + + json choices = + streaming ? json::array({json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}}}) + : json::array({json{{"finish_reason", finish_reason}, + {"index", 0}, + {"message", json{{"content", content}, + {"role", "assistant"}}}}}); + + std::time_t t = std::time(0); + + json res = + json{{"choices", choices}, + {"created", t}, + {"model", + json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, + {"usage", + json{{"completion_tokens", num_tokens_predicted}, + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, + {"id", gen_chatcmplid()}}; + + if (server_verbose) { + res["__verbose"] = result; + } + + if (result.contains("completion_probabilities")) { + res["completion_probabilities"] = + json_value(result, "completion_probabilities", json::array()); + } + + return res; +} + +static std::vector format_partial_response_oaicompat(task_result response) { + json result = response.result_json; + + if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { + return std::vector({response.result_json}); + } + + bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; + std::string modelname = + json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + + bool stopped_word = json_value(result, "stopped_word", false); + bool stopped_eos = json_value(result, "stopped_eos", false); + bool stopped_limit = json_value(result, "stopped_limit", false); + std::string content = json_value(result, "content", std::string("")); + + std::string finish_reason = ""; + if (stopped_word || stopped_eos) { + finish_reason = "stop"; + } + if (stopped_limit) { + finish_reason = "length"; + } + + std::time_t t = std::time(0); + + json choices; + + if (!finish_reason.empty()) { + choices = json::array({json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}}}); + } else { + if (first) { + if (content.empty()) { + choices = json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{{"role", "assistant"}}}}}); + } else { + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{{"choices", + json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; + + json second_ret = json{{"choices", + json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"content", content}}}}})}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; + return std::vector({initial_ret, second_ret}); + } + } else { + // Some idosyncrasy in task processing logic makes several trailing calls + // with empty content, we ignore these at the calee site. + if (content.empty()) { + return std::vector({json::object()}); + } + choices = json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", + json{ + {"content", content}, + }}, + }}); + } + } + + json ret = json{{"choices", choices}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; + + return std::vector({ret}); +} + static json format_partial_response( llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector &probs ) { @@ -2396,6 +2669,78 @@ int main(int argc, char **argv) } }); + + svr.Post("/v1/chat/completions", [&llama](const httplib::Request &req, + httplib::Response &res) { + json data = oaicompat_completion_params_parse(json::parse(req.body)); + + const int task_id = llama.request_completion(data, false, false); + if (!json_value(data, "stream", false)) { + std::string completion_text; + task_result result = llama.next_result(task_id); + + if (!result.error && result.stop) { + json oaicompat_result = format_final_response_oaicompat(data, result); + + res.set_content(oaicompat_result.dump(-1, ' ', false, + json::error_handler_t::replace), + "application/json"); + } else { + res.status = 500; + res.set_content(result.result_json["content"], "text/plain"); + return; + } + } else { + const auto chunked_content_provider = [task_id, &llama](size_t, + httplib::DataSink &sink) { + while (true) { + task_result llama_result = llama.next_result(task_id); + if (!llama_result.error) { + std::vector result_array = format_partial_response_oaicompat( llama_result); + + for (auto it = result_array.begin(); it != result_array.end(); ++it) + { + if (!it->empty()) { + const std::string str = + "data: " + + it->dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + LOG_VERBOSE("data stream", {{"to_send", str}}); + if (!sink.write(str.c_str(), str.size())) { + return false; + } + } + } + if (llama_result.stop) { + break; + } + } else { + const std::string str = + "error: " + + llama_result.result_json.dump(-1, ' ', false, + json::error_handler_t::replace) + + "\n\n"; + LOG_VERBOSE("data stream", {{"to_send", str}}); + if (!sink.write(str.c_str(), str.size())) { + return false; + } + break; + } + } + sink.done(); + return true; + }; + + auto on_complete = [task_id, &llama](bool) { + // cancel + llama.request_cancel(task_id); + }; + + res.set_chunked_content_provider("text/event-stream", + chunked_content_provider, on_complete); + } + }); + svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res) { json data = json::parse(req.body);