mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
server : embeddings compatibility for OpenAI (#5190)
This commit is contained in:
parent
14fef85e2d
commit
c82d18e863
@ -206,3 +206,18 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
|
|||||||
|
|
||||||
return std::vector<json>({ret});
|
return std::vector<json>({ret});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings)
|
||||||
|
{
|
||||||
|
json res =
|
||||||
|
json{
|
||||||
|
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||||
|
{"object", "list"},
|
||||||
|
{"usage",
|
||||||
|
json{{"prompt_tokens", 0},
|
||||||
|
{"total_tokens", 0}}},
|
||||||
|
{"data", embeddings}
|
||||||
|
};
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -2929,6 +2929,66 @@ int main(int argc, char **argv)
|
|||||||
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
||||||
});
|
});
|
||||||
|
|
||||||
|
svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
|
||||||
|
{
|
||||||
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
const json body = json::parse(req.body);
|
||||||
|
|
||||||
|
json prompt;
|
||||||
|
if (body.count("input") != 0)
|
||||||
|
{
|
||||||
|
prompt = body["input"];
|
||||||
|
// batch
|
||||||
|
if(prompt.is_array()) {
|
||||||
|
json data = json::array();
|
||||||
|
int i = 0;
|
||||||
|
for (const json &elem : prompt) {
|
||||||
|
const int task_id = llama.queue_tasks.get_new_id();
|
||||||
|
llama.queue_results.add_waiting_task_id(task_id);
|
||||||
|
llama.request_completion(task_id, { {"prompt", elem}, { "n_predict", 0} }, false, true, -1);
|
||||||
|
|
||||||
|
// get the result
|
||||||
|
task_result result = llama.queue_results.recv(task_id);
|
||||||
|
llama.queue_results.remove_waiting_task_id(task_id);
|
||||||
|
|
||||||
|
json embedding = json{
|
||||||
|
{"embedding", json_value(result.result_json, "embedding", json::array())},
|
||||||
|
{"index", i++},
|
||||||
|
{"object", "embedding"}
|
||||||
|
};
|
||||||
|
data.push_back(embedding);
|
||||||
|
}
|
||||||
|
json result = format_embeddings_response_oaicompat(body, data);
|
||||||
|
return res.set_content(result.dump(), "application/json; charset=utf-8");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
prompt = "";
|
||||||
|
}
|
||||||
|
|
||||||
|
// create and queue the task
|
||||||
|
const int task_id = llama.queue_tasks.get_new_id();
|
||||||
|
llama.queue_results.add_waiting_task_id(task_id);
|
||||||
|
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}}, false, true, -1);
|
||||||
|
|
||||||
|
// get the result
|
||||||
|
task_result result = llama.queue_results.recv(task_id);
|
||||||
|
llama.queue_results.remove_waiting_task_id(task_id);
|
||||||
|
|
||||||
|
json data = json::array({json{
|
||||||
|
{"embedding", json_value(result.result_json, "embedding", json::array())},
|
||||||
|
{"index", 0},
|
||||||
|
{"object", "embedding"}
|
||||||
|
}}
|
||||||
|
);
|
||||||
|
|
||||||
|
json root = format_embeddings_response_oaicompat(body, data);
|
||||||
|
|
||||||
|
// send the result
|
||||||
|
return res.set_content(root.dump(), "application/json; charset=utf-8");
|
||||||
|
});
|
||||||
|
|
||||||
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
|
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
|
||||||
// "Bus error: 10" - this is on macOS, it does not crash on Linux
|
// "Bus error: 10" - this is on macOS, it does not crash on Linux
|
||||||
//std::thread t2([&]()
|
//std::thread t2([&]()
|
||||||
|
Loading…
Reference in New Issue
Block a user