Server: Use multi-task for embeddings endpoint (#6001)

* use multitask for embd endpoint

* specify types

* remove redundant {"n_predict", 0}
This commit is contained in:
Xuan Son Nguyen 2024-03-13 11:39:11 +01:00 committed by GitHub
parent 306d34be7a
commit 99b71c068f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 50 deletions

View File

@ -2763,6 +2763,7 @@ int main(int argc, char ** argv) {
res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "POST"); res.set_header("Access-Control-Allow-Methods", "POST");
res.set_header("Access-Control-Allow-Headers", "*"); res.set_header("Access-Control-Allow-Headers", "*");
return res.set_content("", "application/json; charset=utf-8");
}); });
svr->set_logger(log_server_request); svr->set_logger(log_server_request);
@ -3371,44 +3372,37 @@ int main(int argc, char ** argv) {
const json body = json::parse(req.body); const json body = json::parse(req.body);
bool is_openai = false; bool is_openai = false;
// an input prompt can string or a list of tokens (integer) // an input prompt can be a string or a list of tokens (integer)
std::vector<json> prompts; json prompt;
if (body.count("input") != 0) { if (body.count("input") != 0) {
is_openai = true; is_openai = true;
if (body["input"].is_array()) { prompt = body["input"];
// support multiple prompts
for (const json & elem : body["input"]) {
prompts.push_back(elem);
}
} else {
// single input prompt
prompts.push_back(body["input"]);
}
} else if (body.count("content") != 0) { } else if (body.count("content") != 0) {
// only support single prompt here // with "content", we only support single prompt
std::string content = body["content"]; prompt = std::vector<std::string>{body["content"]};
prompts.push_back(content);
} else { } else {
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return; return;
} }
// process all prompts // create and queue the task
json responses = json::array(); json responses;
for (auto & prompt : prompts) { {
// TODO @ngxson : maybe support multitask for this endpoint?
// create and queue the task
const int id_task = ctx_server.queue_tasks.get_new_id(); const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task); ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true); ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
// get the result // get the result
server_task_result result = ctx_server.queue_results.recv(id_task); server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
if (!result.error) { if (!result.error) {
// append to the responses if (result.data.count("results")) {
responses.push_back(result.data); // result for multi-task
responses = result.data["results"];
} else {
// result for single task
responses = std::vector<json>{result.data};
}
} else { } else {
// error received, ignore everything else // error received, ignore everything else
res_error(res, result.data); res_error(res, result.data);
@ -3417,24 +3411,19 @@ int main(int argc, char ** argv) {
} }
// write JSON response // write JSON response
json root; json root = is_openai
if (is_openai) { ? format_embeddings_response_oaicompat(body, responses)
json res_oai = json::array(); : responses[0];
int i = 0;
for (auto & elem : responses) {
res_oai.push_back(json{
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
});
}
root = format_embeddings_response_oaicompat(body, res_oai);
} else {
root = responses[0];
}
return res.set_content(root.dump(), "application/json; charset=utf-8"); return res.set_content(root.dump(), "application/json; charset=utf-8");
}; };
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
return false;
};
};
// //
// Router // Router
// //
@ -3446,17 +3435,6 @@ int main(int argc, char ** argv) {
} }
// using embedded static files // using embedded static files
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
return false;
};
};
svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
// TODO @ngxson : I have no idea what it is... maybe this is redundant?
return res.set_content("", "application/json; charset=utf-8");
});
svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8")); svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8")); svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8")); svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));

View File

@ -529,6 +529,16 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
} }
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) { static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
json data = json::array();
int i = 0;
for (auto & elem : embeddings) {
data.push_back(json{
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
});
}
json res = json { json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"}, {"object", "list"},
@ -536,7 +546,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
{"prompt_tokens", 0}, {"prompt_tokens", 0},
{"total_tokens", 0} {"total_tokens", 0}
}}, }},
{"data", embeddings} {"data", data}
}; };
return res; return res;