diff --git a/.gitignore b/.gitignore index 545c28726..5d7c5479e 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ .DS_Store .build/ .cache/ +.ccls-cache/ .direnv/ .envrc .swiftpm diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c163c7f8e..47ae0d558 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -149,6 +149,7 @@ struct task_server { task_type type; json data; bool infill_mode = false; + bool embedding_mode = false; }; struct task_result { @@ -371,6 +372,7 @@ struct llama_client_slot std::vector generated_token_probs; bool infill = false; + bool embedding = false; bool has_next_token = true; bool truncated = false; bool stopped_eos = false; @@ -1244,13 +1246,14 @@ struct llama_server_context queue_results.push_back(res); } - int request_completion(json data, bool infill) + int request_completion(json data, bool infill, bool embedding) { std::lock_guard lock(mutex_tasks); task_server task; task.id = id_gen++; task.data = data; task.infill_mode = infill; + task.embedding_mode = embedding; task.type = COMPLETION_TASK; queue_tasks.push_back(task); return task.id; @@ -1376,7 +1379,7 @@ struct llama_server_context { LOG_TEE("slot unavailable\n"); // send error result - send_error(task.id, "slot unavaliable"); + send_error(task.id, "slot unavailable"); return; } @@ -1388,6 +1391,7 @@ struct llama_server_context slot->reset(); slot->infill = task.infill_mode; + slot->embedding = task.embedding_mode; slot->task_id = task.id; if (!launch_slot_with_data(slot, task.data)) @@ -1695,7 +1699,7 @@ struct llama_server_context } // prompt evaluated for embedding - if (params.embedding) + if (slot.embedding) { send_embedding(slot); slot.release(); @@ -2274,7 +2278,7 @@ int main(int argc, char **argv) svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res) { json data = json::parse(req.body); - const int task_id = llama.request_completion(data, false); + const int task_id = llama.request_completion(data, false, false); if (!json_value(data, "stream", false)) { std::string completion_text; task_result result = llama.next_result(task_id); @@ -2329,7 +2333,7 @@ int main(int argc, char **argv) svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res) { json data = json::parse(req.body); - const int task_id = llama.request_completion(data, true); + const int task_id = llama.request_completion(data, true, false); if (!json_value(data, "stream", false)) { std::string completion_text; task_result result = llama.next_result(task_id); @@ -2433,7 +2437,7 @@ int main(int argc, char **argv) { prompt = ""; } - const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false); + const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true); task_result result = llama.next_result(task_id); return res.set_content(result.result_json.dump(), "application/json"); });