mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
Add openai-compatible POST /v1/chat/completions API endpoint to server example
This commit is contained in:
parent
8e672efe63
commit
a0a08eedb6
@ -29,6 +29,8 @@
|
||||
#define SERVER_VERBOSE 1
|
||||
#endif
|
||||
|
||||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
struct server_params
|
||||
@ -63,6 +65,10 @@ static bool server_verbose = false;
|
||||
// base64 utils (TODO: move to common in the future)
|
||||
//
|
||||
|
||||
nlohmann::json oaicompat_completion_params_parse(
|
||||
const nlohmann::json &body);
|
||||
std::string format_chatml(std::vector<json> messages);
|
||||
|
||||
static const std::string base64_chars =
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
"abcdefghijklmnopqrstuvwxyz"
|
||||
@ -377,6 +383,9 @@ struct llama_client_slot
|
||||
bool stopped_eos = false;
|
||||
bool stopped_word = false;
|
||||
bool stopped_limit = false;
|
||||
|
||||
bool oaicompat = false;
|
||||
std::string oaicompat_model = "";
|
||||
|
||||
std::string stopping_word;
|
||||
|
||||
@ -676,7 +685,16 @@ struct llama_server_context
|
||||
bool launch_slot_with_data(llama_client_slot* &slot, json data) {
|
||||
slot_params default_params;
|
||||
llama_sampling_params default_sparams;
|
||||
|
||||
|
||||
if (data.count("__oaicompat") != 0) {
|
||||
slot->oaicompat = true;
|
||||
slot->oaicompat_model =
|
||||
json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
||||
} else {
|
||||
slot->oaicompat = false;
|
||||
slot->oaicompat_model = "";
|
||||
}
|
||||
|
||||
slot->params.stream = json_value(data, "stream", false);
|
||||
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
|
||||
slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
|
||||
@ -1169,6 +1187,12 @@ struct llama_server_context
|
||||
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
|
||||
}
|
||||
|
||||
if (slot.oaicompat)
|
||||
{
|
||||
res.result_json["oaicompat_token_ctr"] = slot.n_decoded;
|
||||
res.result_json["model"] = slot.oaicompat_model;
|
||||
}
|
||||
|
||||
queue_results.push_back(res);
|
||||
}
|
||||
|
||||
@ -1216,6 +1240,12 @@ struct llama_server_context
|
||||
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs);
|
||||
}
|
||||
|
||||
if (slot.oaicompat)
|
||||
{
|
||||
res.result_json["oaicompat_token_ctr"] = slot.n_decoded;
|
||||
res.result_json["model"] = slot.oaicompat_model;
|
||||
}
|
||||
|
||||
queue_results.push_back(res);
|
||||
}
|
||||
|
||||
@ -2178,6 +2208,249 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static std::string random_string() {
|
||||
std::string str(
|
||||
"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 generator(rd());
|
||||
|
||||
std::shuffle(str.begin(), str.end(), generator);
|
||||
|
||||
return str.substr(0, 32); // assumes 32 < number of characters in str
|
||||
}
|
||||
|
||||
static std::string gen_chatcmplid() {
|
||||
std::stringstream chatcmplid;
|
||||
chatcmplid << "chatcmpl-" << random_string();
|
||||
return chatcmplid.str();
|
||||
}
|
||||
|
||||
std::string format_chatml(std::vector<json> messages) {
|
||||
|
||||
std::ostringstream chatml_msgs;
|
||||
|
||||
// iterate the array
|
||||
for (auto it = messages.begin(); it != messages.end(); ++it) {
|
||||
chatml_msgs << "<|im_start|>"
|
||||
<< json_value(*it, "role", std::string("user")) << '\n';
|
||||
chatml_msgs << json_value(*it, "content", std::string(""))
|
||||
<< "<|im_end|>\n";
|
||||
}
|
||||
|
||||
chatml_msgs << "<|im_start|>assistant" << '\n';
|
||||
|
||||
return chatml_msgs.str();
|
||||
}
|
||||
|
||||
/* llama.cpp completion api semantics */
|
||||
nlohmann::json oaicompat_completion_params_parse(
|
||||
const nlohmann::json &body /* openai api json semantics */) {
|
||||
nlohmann::json llama_params;
|
||||
|
||||
llama_params["__oaicompat"] = true;
|
||||
|
||||
// Map OpenAI parameters to llama.cpp parameters
|
||||
llama_params["prompt"] = format_chatml(
|
||||
body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
|
||||
llama_params["temperature"] =
|
||||
json_value(body, "temperature", 0.8); // Default to 0.8 if not provided
|
||||
llama_params["top_k"] =
|
||||
json_value(body, "max_tokens", 40); // Default to 40 if not provided
|
||||
llama_params["top_p"] =
|
||||
json_value(body, "top_p", 0.95); // Default to 0.95 if not provided
|
||||
llama_params["n_predict"] =
|
||||
json_value(body, "max_tokens", -1); // Default to -1 if not provided
|
||||
llama_params["logit_bias"] = json_value(
|
||||
body, "logit_bias",
|
||||
nlohmann::json::object()); // Default to empty object if not provided
|
||||
llama_params["frequency_penalty"] = json_value(
|
||||
body, "frequency_penalty", 0.0); // Default to 0.0 if not provided
|
||||
llama_params["presence_penalty"] = json_value(
|
||||
body, "presence_penalty", 0.0); // Default to 0.0 if not provided
|
||||
llama_params["seed"] = json_value(body, "seed", 0);
|
||||
llama_params["stream"] =
|
||||
json_value(body, "stream", false); // Default to 0 if not provided
|
||||
llama_params["mirostat"] =
|
||||
json_value(body, "mirostat", false); // Default to false if not provided
|
||||
llama_params["mirostat_tau"] =
|
||||
json_value(body, "mirostat_tau", 0.0); // Default to 0.0 if not provided
|
||||
llama_params["mirostat_eta"] =
|
||||
json_value(body, "mirostat_eta", 0.0); // Default to 0.0 if not provided
|
||||
llama_params["penalize_nl"] = json_value(
|
||||
body, "penalize_nl", false); // Default to false if not provided
|
||||
llama_params["typical_p"] =
|
||||
json_value(body, "typical_p", 0.0); // Default to 0.0 if not provided
|
||||
llama_params["repeat_last_n"] =
|
||||
json_value(body, "repeat_last_n", 0); // Default to 0 if not provided
|
||||
llama_params["ignore_eos"] =
|
||||
json_value(body, "ignore_eos", false); // Default to false if not provided
|
||||
llama_params["tfs_z"] =
|
||||
json_value(body, "tfs_z", 0.0); // Default to 0.0 if not provided
|
||||
if (llama_params.count("grammar") != 0) {
|
||||
llama_params["grammar"] = json_value(
|
||||
body, "grammar",
|
||||
nlohmann::json::object()); // Default to empty object if not provided
|
||||
}
|
||||
|
||||
// Handle 'stop' field
|
||||
if (body["stop"].is_null()) {
|
||||
llama_params["stop"] = json::array({});
|
||||
} else if (body["stop"].is_string()) {
|
||||
llama_params["stop"] = json::array({body["stop"].get<std::string>()});
|
||||
} else {
|
||||
llama_params["stop"] = json_value(
|
||||
body, "stop",
|
||||
json::array()); // Default to empty array if not provided
|
||||
}
|
||||
|
||||
llama_params["stop"].push_back("<|im_end|>");
|
||||
|
||||
return llama_params;
|
||||
}
|
||||
|
||||
static json format_final_response_oaicompat(json request, task_result response,
|
||||
bool streaming = false) {
|
||||
|
||||
json result = response.result_json;
|
||||
|
||||
bool stopped_word = result.count("stopped_word") != 0;
|
||||
bool stopped_eos = json_value(result, "stopped_eos", false);
|
||||
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
||||
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
||||
std::string content = json_value(result, "content", std::string(""));
|
||||
|
||||
std::string finish_reason = "length";
|
||||
if (stopped_word || stopped_eos) {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
|
||||
json choices =
|
||||
streaming ? json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}}})
|
||||
: json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", json{{"content", content},
|
||||
{"role", "assistant"}}}}});
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
|
||||
json res =
|
||||
json{{"choices", choices},
|
||||
{"created", t},
|
||||
{"model",
|
||||
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
|
||||
{"usage",
|
||||
json{{"completion_tokens", num_tokens_predicted},
|
||||
{"prompt_tokens", num_prompt_tokens},
|
||||
{"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
|
||||
{"id", gen_chatcmplid()}};
|
||||
|
||||
if (server_verbose) {
|
||||
res["__verbose"] = result;
|
||||
}
|
||||
|
||||
if (result.contains("completion_probabilities")) {
|
||||
res["completion_probabilities"] =
|
||||
json_value(result, "completion_probabilities", json::array());
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static std::vector<json> format_partial_response_oaicompat(task_result response) {
|
||||
json result = response.result_json;
|
||||
|
||||
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
|
||||
return std::vector<json>({response.result_json});
|
||||
}
|
||||
|
||||
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
|
||||
std::string modelname =
|
||||
json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
||||
|
||||
bool stopped_word = json_value(result, "stopped_word", false);
|
||||
bool stopped_eos = json_value(result, "stopped_eos", false);
|
||||
bool stopped_limit = json_value(result, "stopped_limit", false);
|
||||
std::string content = json_value(result, "content", std::string(""));
|
||||
|
||||
std::string finish_reason = "";
|
||||
if (stopped_word || stopped_eos) {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
if (stopped_limit) {
|
||||
finish_reason = "length";
|
||||
}
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
|
||||
json choices;
|
||||
|
||||
if (!finish_reason.empty()) {
|
||||
choices = json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}}});
|
||||
} else {
|
||||
if (first) {
|
||||
if (content.empty()) {
|
||||
choices = json::array({json{{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{{"role", "assistant"}}}}});
|
||||
} else {
|
||||
// We have to send this as two updates to conform to openai behavior
|
||||
json initial_ret = json{{"choices",
|
||||
json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{
|
||||
{"role", "assistant"}
|
||||
}}}})},
|
||||
{"created", t},
|
||||
{"id", gen_chatcmplid()},
|
||||
{"model", modelname},
|
||||
{"object", "chat.completion.chunk"}};
|
||||
|
||||
json second_ret = json{{"choices",
|
||||
json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{
|
||||
{"content", content}}}}})},
|
||||
{"created", t},
|
||||
{"id", gen_chatcmplid()},
|
||||
{"model", modelname},
|
||||
{"object", "chat.completion.chunk"}};
|
||||
return std::vector<json>({initial_ret, second_ret});
|
||||
}
|
||||
} else {
|
||||
// Some idosyncrasy in task processing logic makes several trailing calls
|
||||
// with empty content, we ignore these at the calee site.
|
||||
if (content.empty()) {
|
||||
return std::vector<json>({json::object()});
|
||||
}
|
||||
choices = json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta",
|
||||
json{
|
||||
{"content", content},
|
||||
}},
|
||||
}});
|
||||
}
|
||||
}
|
||||
|
||||
json ret = json{{"choices", choices},
|
||||
{"created", t},
|
||||
{"id", gen_chatcmplid()},
|
||||
{"model", modelname},
|
||||
{"object", "chat.completion.chunk"}};
|
||||
|
||||
return std::vector<json>({ret});
|
||||
}
|
||||
|
||||
static json format_partial_response(
|
||||
llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector<completion_token_output> &probs
|
||||
) {
|
||||
@ -2396,6 +2669,78 @@ int main(int argc, char **argv)
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
svr.Post("/v1/chat/completions", [&llama](const httplib::Request &req,
|
||||
httplib::Response &res) {
|
||||
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
||||
|
||||
const int task_id = llama.request_completion(data, false, false);
|
||||
if (!json_value(data, "stream", false)) {
|
||||
std::string completion_text;
|
||||
task_result result = llama.next_result(task_id);
|
||||
|
||||
if (!result.error && result.stop) {
|
||||
json oaicompat_result = format_final_response_oaicompat(data, result);
|
||||
|
||||
res.set_content(oaicompat_result.dump(-1, ' ', false,
|
||||
json::error_handler_t::replace),
|
||||
"application/json");
|
||||
} else {
|
||||
res.status = 500;
|
||||
res.set_content(result.result_json["content"], "text/plain");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
const auto chunked_content_provider = [task_id, &llama](size_t,
|
||||
httplib::DataSink &sink) {
|
||||
while (true) {
|
||||
task_result llama_result = llama.next_result(task_id);
|
||||
if (!llama_result.error) {
|
||||
std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
|
||||
|
||||
for (auto it = result_array.begin(); it != result_array.end(); ++it)
|
||||
{
|
||||
if (!it->empty()) {
|
||||
const std::string str =
|
||||
"data: " +
|
||||
it->dump(-1, ' ', false, json::error_handler_t::replace) +
|
||||
"\n\n";
|
||||
LOG_VERBOSE("data stream", {{"to_send", str}});
|
||||
if (!sink.write(str.c_str(), str.size())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (llama_result.stop) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
const std::string str =
|
||||
"error: " +
|
||||
llama_result.result_json.dump(-1, ' ', false,
|
||||
json::error_handler_t::replace) +
|
||||
"\n\n";
|
||||
LOG_VERBOSE("data stream", {{"to_send", str}});
|
||||
if (!sink.write(str.c_str(), str.size())) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
sink.done();
|
||||
return true;
|
||||
};
|
||||
|
||||
auto on_complete = [task_id, &llama](bool) {
|
||||
// cancel
|
||||
llama.request_cancel(task_id);
|
||||
};
|
||||
|
||||
res.set_chunked_content_provider("text/event-stream",
|
||||
chunked_content_provider, on_complete);
|
||||
}
|
||||
});
|
||||
|
||||
svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
|
||||
{
|
||||
json data = json::parse(req.body);
|
||||
|
Loading…
Reference in New Issue
Block a user