server : re-enable completion and embedded at the same time (#3876)

Adrian Hesketh 2023-11-01 09:28:28 +00:00 committed by GitHub
parent 71e3718abd
commit ca190bca8e
2 changed files with 11 additions and 6 deletions

.gitignore

@@ -15,6 +15,7 @@
 .DS_Store
 .build/
 .cache/
+.ccls-cache/
 .direnv/
 .envrc
 .swiftpm

examples/server/server.cpp

@@ -149,6 +149,7 @@ struct task_server {
     task_type type;
     json data;
     bool infill_mode = false;
+    bool embedding_mode = false;
 };

 struct task_result {
@@ -371,6 +372,7 @@ struct llama_client_slot
     std::vector<completion_token_output> generated_token_probs;

     bool infill = false;
+    bool embedding = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1244,13 +1246,14 @@ struct llama_server_context
         queue_results.push_back(res);
     }

-    int request_completion(json data, bool infill)
+    int request_completion(json data, bool infill, bool embedding)
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.data = data;
         task.infill_mode = infill;
+        task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
         queue_tasks.push_back(task);
         return task.id;
@@ -1376,7 +1379,7 @@ struct llama_server_context
                     {
                         LOG_TEE("slot unavailable\n");
                         // send error result
-                        send_error(task.id, "slot unavaliable");
+                        send_error(task.id, "slot unavailable");
                         return;
                     }
@@ -1388,6 +1391,7 @@ struct llama_server_context
                     slot->reset();

                     slot->infill = task.infill_mode;
+                    slot->embedding = task.embedding_mode;
                     slot->task_id = task.id;

                     if (!launch_slot_with_data(slot, task.data))
@@ -1695,7 +1699,7 @@ struct llama_server_context
                 }

                 // prompt evaluated for embedding
-                if (params.embedding)
+                if (slot.embedding)
                 {
                     send_embedding(slot);
                     slot.release();
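
This hunk is the behavioral core of the change: whether a slot answers with an embedding is now decided by the flag copied onto that slot from its task, rather than by the server-wide params.embedding, so completion and embedding requests can be served by the same process at the same time. Below is a small, self-contained sketch of that per-slot dispatch; it is not the server code, and Task, Slot and the loop structure are illustrative names only.

// Stand-alone sketch of the per-slot dispatch introduced by this commit.
// Task::embedding_mode mirrors task_server::embedding_mode and
// Slot::embedding mirrors llama_client_slot::embedding from the diff above.
#include <iostream>
#include <vector>

struct Task {
    int  id;
    bool embedding_mode;   // set per request, not per server
};

struct Slot {
    int  task_id   = -1;
    bool embedding = false;
};

int main() {
    // A mixed queue: two completion requests and one embedding request.
    std::vector<Task> queue = { {0, false}, {1, true}, {2, false} };
    std::vector<Slot> slots(queue.size());

    // Each task claims a slot and copies its mode onto it,
    // the way process_tasks() copies task.embedding_mode in the hunks above.
    for (size_t i = 0; i < queue.size(); ++i) {
        slots[i].task_id   = queue[i].id;
        slots[i].embedding = queue[i].embedding_mode;
    }

    // After prompt evaluation, each slot decides for itself what to send back.
    for (const Slot &slot : slots) {
        if (slot.embedding) {
            std::cout << "task " << slot.task_id << ": send embedding\n";
        } else {
            std::cout << "task " << slot.task_id << ": keep generating tokens\n";
        }
    }
    return 0;
}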
@@ -2274,7 +2278,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false);
+                const int task_id = llama.request_completion(data, false, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2329,7 +2333,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true);
+                const int task_id = llama.request_completion(data, true, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2433,7 +2437,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });
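
Taken together, these hunks mean the /completion, /infill and /embedding handlers each tag their task with the right mode, so one running server can answer both kinds of request concurrently. The sketch below is a hedged usage example, not part of the commit: it uses cpp-httplib, the same HTTP library the server example depends on, to issue a completion and an embedding request in parallel. The host, port and request bodies are assumptions; adjust them to how the server was started (the /embedding reply presumably only carries a vector if the server was launched with embedding support, i.e. the --embedding flag).

// Hypothetical client: posts to /completion and /embedding at the same time.
#include <iostream>
#include <thread>

#include "httplib.h"   // cpp-httplib

int main() {
    std::thread completion([] {
        httplib::Client cli("localhost", 8080);              // assumed host/port
        auto res = cli.Post("/completion",
                            R"({"prompt": "Hello", "n_predict": 16})",
                            "application/json");
        if (res) std::cout << "completion: " << res->body << "\n";
    });

    std::thread embedding([] {
        httplib::Client cli("localhost", 8080);              // assumed host/port
        auto res = cli.Post("/embedding",
                            R"({"content": "Hello"})",
                            "application/json");
        if (res) std::cout << "embedding: " << res->body << "\n";
    });

    completion.join();
    embedding.join();
    return 0;
}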