Server: reorganize some http logic (#5939)
* refactor static file handler
* use set_pre_routing_handler for validate_api_key
* merge embedding handlers
* correct http verb for endpoints
* fix embedding response
* fix test case CORS Options
* fix code style
Parent: e1fa9569ba
Commit: 950ba1ab84
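In short, the commit turns the inline `svr->Get(...)` / `svr->Post(...)` registrations into named `handle_*` lambdas that are wired up together in one "Router" section, and moves API-key validation into a pre-routing middleware. A rough, heavily simplified sketch of the resulting layout (this is not the actual server code; only the shape and the handler naming are taken from the diff below):

// sketch only: handlers are named lambdas, registered together in one place
#include "httplib.h"

int main() {
    httplib::Server svr;

    const auto handle_health = [](const httplib::Request &, httplib::Response & res) {
        res.set_content("{\"status\": \"ok\"}", "application/json; charset=utf-8");
    };

    // register API routes
    svr.Get("/health", handle_health);

    svr.listen("127.0.0.1", 8080);
    return 0;
}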
@@ -42,7 +42,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
 - `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
 - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
 - `--port`: Set the port to listen. Default: `8080`.
-- `--path`: path from which to serve static files (default examples/server/public)
+- `--path`: path from which to serve static files (default: disabled)
 - `--api-key`: Set an api key for request authorization. By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys.
 - `--api-key-file`: path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`'s.
 - `--embedding`: Enable embedding extraction, Default: disabled.
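As a usage note for the `--api-key` / `--api-key-file` options above: the key is expected as a Bearer token in the `Authorization` header. A minimal client-side sketch using cpp-httplib (host, port and the key are placeholders):

#include "httplib.h"
#include <cstdio>

int main() {
    httplib::Client cli("127.0.0.1", 8080);
    httplib::Headers headers = { {"Authorization", "Bearer YOUR-API-KEY"} };

    // a protected endpoint: 200 with a valid key, 401 otherwise
    auto res = cli.Get("/props", headers);
    if (res) {
        std::printf("status: %d\n", res->status);
    }
    return 0;
}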
@@ -558,7 +558,7 @@ The HTTP server supports OAI-like API

 ### Extending or building alternative Web Front End

-The default location for the static files is `examples/server/public`. You can extend the front end by running the server binary with `--path` set to `./your-directory` and importing `/completion.js` to get access to the llamaComplete() method.
+You can extend the front end by running the server binary with `--path` set to `./your-directory` and importing `/completion.js` to get access to the llamaComplete() method.

 Read the documentation in `/completion.js` to see convenient ways to access llama.

@@ -113,7 +113,7 @@ struct server_params {
     int32_t n_threads_http = -1;

     std::string hostname = "127.0.0.1";
-    std::string public_path = "examples/server/public";
+    std::string public_path = "";
     std::string chat_template = "";
     std::string system_prompt = "";

@@ -2145,7 +2145,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
     printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
     printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
-    printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+    printf(" --path PUBLIC_PATH path from which to serve static files (default: disabled)\n");
     printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
     printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -2211,7 +2211,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 invalid_param = true;
                 break;
             }
-            sparams.api_keys.emplace_back(argv[i]);
+            sparams.api_keys.push_back(argv[i]);
         } else if (arg == "--api-key-file") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2712,7 +2712,143 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Headers", "*");
     });

-    svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
+    svr->set_logger(log_server_request);
+
+    svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
+        const char fmt[] = "500 Internal Server Error\n%s";
+
+        char buf[BUFSIZ];
+        try {
+            std::rethrow_exception(std::move(ep));
+        } catch (std::exception &e) {
+            snprintf(buf, sizeof(buf), fmt, e.what());
+        } catch (...) {
+            snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
+        }
+
+        res.set_content(buf, "text/plain; charset=utf-8");
+        res.status = 500;
+    });
+
+    svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
+        if (res.status == 401) {
+            res.set_content("Unauthorized", "text/plain; charset=utf-8");
+        }
+        if (res.status == 400) {
+            res.set_content("Invalid request", "text/plain; charset=utf-8");
+        }
+        if (res.status == 404) {
+            res.set_content("File Not Found", "text/plain; charset=utf-8");
+        }
+    });
+
+    // set timeouts and change hostname and port
+    svr->set_read_timeout (sparams.read_timeout);
+    svr->set_write_timeout(sparams.write_timeout);
+
+    if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
+        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
+        return 1;
+    }
+
+    std::unordered_map<std::string, std::string> log_data;
+
+    log_data["hostname"] = sparams.hostname;
+    log_data["port"] = std::to_string(sparams.port);
+
+    if (sparams.api_keys.size() == 1) {
+        auto key = sparams.api_keys[0];
+        log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0));
+    } else if (sparams.api_keys.size() > 1) {
+        log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
+    }
+
+    // load the model
+    if (!ctx_server.load_model(params)) {
+        state.store(SERVER_STATE_ERROR);
+        return 1;
+    } else {
+        ctx_server.initialize();
+        state.store(SERVER_STATE_READY);
+    }
+
+    LOG_INFO("model loaded", {});
+
+    const auto model_meta = ctx_server.model_meta();
+
+    if (sparams.chat_template.empty()) { // custom chat template is not supplied
+        if (!ctx_server.validate_model_chat_template()) {
+            LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
+            sparams.chat_template = "chatml";
+        }
+    }
+
+    //
+    // Middlewares
+    //
+
+    auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
+        // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
+        static const std::set<std::string> protected_endpoints = {
+            "/props",
+            "/completion",
+            "/completions",
+            "/v1/completions",
+            "/chat/completions",
+            "/v1/chat/completions",
+            "/infill",
+            "/tokenize",
+            "/detokenize",
+            "/embedding",
+            "/embeddings",
+            "/v1/embeddings",
+        };
+
+        // If API key is not set, skip validation
+        if (sparams.api_keys.empty()) {
+            return true;
+        }
+
+        // If path is not in protected_endpoints list, skip validation
+        if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
+            return true;
+        }
+
+        // Check for API key in the header
+        auto auth_header = req.get_header_value("Authorization");
+
+        std::string prefix = "Bearer ";
+        if (auth_header.substr(0, prefix.size()) == prefix) {
+            std::string received_api_key = auth_header.substr(prefix.size());
+            if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
+                return true; // API key is valid
+            }
+        }
+
+        // API key is invalid or not provided
+        // TODO: make another middleware for CORS related logic
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
+        res.status = 401; // Unauthorized
+
+        LOG_WARNING("Unauthorized: Invalid API Key", {});
+
+        return false;
+    };
+
+    // register server middlewares
+    svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
+        if (!middleware_validate_api_key(req, res)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
+        return httplib::Server::HandlerResponse::Unhandled;
+    });
+
+    //
+    // Route handlers (or controllers)
+    //
+
+    const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
         server_state current_state = state.load();
         switch (current_state) {
             case SERVER_STATE_READY:
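A note on the middleware registration above: `set_pre_routing_handler` runs before any route is matched. Returning `HandlerResponse::Handled` short-circuits the request (whatever has been written to `res` is sent back), while `Unhandled` lets it fall through to the registered routes. A minimal standalone sketch of that contract, not the server code itself:

#include "httplib.h"

int main() {
    httplib::Server svr;

    svr.set_pre_routing_handler([](const httplib::Request & req, httplib::Response & res) {
        if (req.get_header_value("Authorization").empty()) {
            res.status = 401;
            res.set_content("Unauthorized", "text/plain; charset=utf-8");
            return httplib::Server::HandlerResponse::Handled;   // stop here
        }
        return httplib::Server::HandlerResponse::Unhandled;     // fall through to routes
    });

    svr.Get("/ping", [](const httplib::Request &, httplib::Response & res) {
        res.set_content("pong", "text/plain; charset=utf-8");
    });

    svr.listen("127.0.0.1", 8081);
    return 0;
}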
@@ -2765,10 +2901,15 @@ int main(int argc, char ** argv) {
                 res.status = 500; // HTTP Internal Server Error
             } break;
         }
-    });
+    };

-    if (sparams.slots_endpoint) {
-        svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
+    const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
+        if (!sparams.slots_endpoint) {
+            res.status = 501;
+            res.set_content("This server does not support slots endpoint.", "text/plain; charset=utf-8");
+            return;
+        }
+
         // request slots data using task queue
         server_task task;
         task.id = ctx_server.queue_tasks.get_new_id();
@@ -2785,11 +2926,15 @@ int main(int argc, char ** argv) {

         res.set_content(result.data["slots"].dump(), "application/json");
         res.status = 200; // HTTP OK
-    });
+    };
+
+    const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
+        if (!sparams.metrics_endpoint) {
+            res.status = 501;
+            res.set_content("This server does not support metrics endpoint.", "text/plain; charset=utf-8");
+            return;
         }

-    if (sparams.metrics_endpoint) {
-        svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
         // request slots data using task queue
         server_task task;
         task.id = ctx_server.queue_tasks.get_new_id();
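Worth noting about `handle_slots` and `handle_metrics` in the two hunks above: the route is now always registered, and the handler itself answers 501 when the feature is disabled, rather than registering the route conditionally. A small sketch of the pattern, with a plain boolean standing in for `sparams.slots_endpoint`:

#include "httplib.h"

int main() {
    httplib::Server svr;
    bool slots_endpoint = false; // stand-in for sparams.slots_endpoint

    const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
        if (!slots_endpoint) {
            res.status = 501; // Not Implemented
            res.set_content("This server does not support slots endpoint.", "text/plain; charset=utf-8");
            return;
        }
        res.set_content("[]", "application/json; charset=utf-8");
    };

    svr.Get("/slots", handle_slots);
    svr.listen("127.0.0.1", 8080);
    return 0;
}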
@@ -2883,134 +3028,9 @@ int main(int argc, char ** argv) {

         res.set_content(prometheus.str(), "text/plain; version=0.0.4");
         res.status = 200; // HTTP OK
-        });
-    }
-
-    svr->set_logger(log_server_request);
-
-    svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
-        const char fmt[] = "500 Internal Server Error\n%s";
-
-        char buf[BUFSIZ];
-        try {
-            std::rethrow_exception(std::move(ep));
-        } catch (std::exception &e) {
-            snprintf(buf, sizeof(buf), fmt, e.what());
-        } catch (...) {
-            snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
-        }
-
-        res.set_content(buf, "text/plain; charset=utf-8");
-        res.status = 500;
-    });
-
-    svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
-        if (res.status == 401) {
-            res.set_content("Unauthorized", "text/plain; charset=utf-8");
-        }
-        if (res.status == 400) {
-            res.set_content("Invalid request", "text/plain; charset=utf-8");
-        }
-        if (res.status == 404) {
-            res.set_content("File Not Found", "text/plain; charset=utf-8");
-        }
-    });
-
-    // set timeouts and change hostname and port
-    svr->set_read_timeout (sparams.read_timeout);
-    svr->set_write_timeout(sparams.write_timeout);
-
-    if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
-        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
-        return 1;
-    }
-
-    // Set the base directory for serving static files
-    svr->set_base_dir(sparams.public_path);
-
-    std::unordered_map<std::string, std::string> log_data;
-
-    log_data["hostname"] = sparams.hostname;
-    log_data["port"] = std::to_string(sparams.port);
-
-    if (sparams.api_keys.size() == 1) {
-        log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
-    } else if (sparams.api_keys.size() > 1) {
-        log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
-    }
-
-    // load the model
-    if (!ctx_server.load_model(params)) {
-        state.store(SERVER_STATE_ERROR);
-        return 1;
-    } else {
-        ctx_server.initialize();
-        state.store(SERVER_STATE_READY);
-    }
-
-    LOG_INFO("model loaded", {});
-
-    const auto model_meta = ctx_server.model_meta();
-
-    if (sparams.chat_template.empty()) { // custom chat template is not supplied
-        if (!ctx_server.validate_model_chat_template()) {
-            LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
-            sparams.chat_template = "chatml";
-        }
-    }
-
-    // Middleware for API key validation
-    auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
-        // If API key is not set, skip validation
-        if (sparams.api_keys.empty()) {
-            return true;
-        }
-
-        // Check for API key in the header
-        auto auth_header = req.get_header_value("Authorization");
-
-        std::string prefix = "Bearer ";
-        if (auth_header.substr(0, prefix.size()) == prefix) {
-            std::string received_api_key = auth_header.substr(prefix.size());
-            if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
-                return true; // API key is valid
-            }
-        }
-
-        // API key is invalid or not provided
-        res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
-        res.status = 401; // Unauthorized
-
-        LOG_WARNING("Unauthorized: Invalid API Key", {});
-
-        return false;
     };

-    // this is only called if no index.html is found in the public --path
-    svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8");
-        return false;
-    });
-
-    // this is only called if no index.js is found in the public --path
-    svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8");
-        return false;
-    });
-
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
-        return false;
-    });
-
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
-        return false;
-    });
-
-    svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
             { "user_name", ctx_server.name_user.c_str() },
@@ -3020,13 +3040,10 @@ int main(int argc, char ** argv) {
         };

         res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };

-    const auto completions = [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!validate_api_key(req, res)) {
-            return;
-        }

         json data = json::parse(req.body);

@@ -3102,11 +3119,7 @@ int main(int argc, char ** argv) {
         }
     };

-    svr->Post("/completion", completions); // legacy
-    svr->Post("/completions", completions);
-    svr->Post("/v1/completions", completions);
-
-    svr->Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

         json models = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
res.set_content(models.dump(), "application/json; charset=utf-8");
|
res.set_content(models.dump(), "application/json; charset=utf-8");
|
||||||
});
|
};
|
||||||
|
|
||||||
const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
if (!validate_api_key(req, res)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
|
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
|
||||||
|
|
||||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||||
@ -3201,14 +3210,8 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
svr->Post("/chat/completions", chat_completions);
|
const auto handle_infill = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||||
svr->Post("/v1/chat/completions", chat_completions);
|
|
||||||
|
|
||||||
svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
|
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
if (!validate_api_key(req, res)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
|
|
||||||
@ -3266,13 +3269,9 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
||||||
}
|
}
|
||||||
});
|
};
|
||||||
|
|
||||||
svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
|
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||||
return res.set_content("", "application/json; charset=utf-8");
|
|
||||||
});
|
|
||||||
|
|
||||||
svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
|
|
||||||
@@ -3282,9 +3281,9 @@ int main(int argc, char ** argv) {
         }
         const json data = format_tokenizer_response(tokens);
         return res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };

-    svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);

@@ -3296,83 +3295,45 @@ int main(int argc, char ** argv) {

         const json data = format_detokenized_response(content);
         return res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
-
-    svr->Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!params.embedding) {
-            res.status = 501;
-            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
-            return;
-        }
-
-        const json body = json::parse(req.body);
-
-        json prompt;
-        if (body.count("content") != 0) {
-            prompt = body["content"];
-        } else {
-            prompt = "";
-        }
-
-        // create and queue the task
-        const int id_task = ctx_server.queue_tasks.get_new_id();
-
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true);
-
-        // get the result
-        server_task_result result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-        // send the result
-        return res.set_content(result.data.dump(), "application/json; charset=utf-8");
-    });
-
-    svr->Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!params.embedding) {
-            res.status = 501;
-            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
-            return;
-        }
-
-        const json body = json::parse(req.body);
-
-        json prompt;
-        if (body.count("input") != 0) {
-            prompt = body["input"];
-            if (prompt.is_array()) {
-                json data = json::array();
-
-                int i = 0;
-                for (const json & elem : prompt) {
-                    const int id_task = ctx_server.queue_tasks.get_new_id();
-
-                    ctx_server.queue_results.add_waiting_task_id(id_task);
-                    ctx_server.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true);
-
-                    // get the result
-                    server_task_result result = ctx_server.queue_results.recv(id_task);
-                    ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-                    json embedding = json{
-                        {"embedding", json_value(result.data, "embedding", json::array())},
-                        {"index", i++},
-                        {"object", "embedding"}
     };

-                    data.push_back(embedding);
+    const auto handle_embeddings = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        if (!params.embedding) {
+            res.status = 501;
+            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
+            return;
         }

-                json result = format_embeddings_response_oaicompat(body, data);
+        const json body = json::parse(req.body);
+        bool is_openai = false;

-                return res.set_content(result.dump(), "application/json; charset=utf-8");
+        // an input prompt can string or a list of tokens (integer)
+        std::vector<json> prompts;
+        if (body.count("input") != 0) {
+            is_openai = true;
+            if (body["input"].is_array()) {
+                // support multiple prompts
+                for (const json & elem : body["input"]) {
+                    prompts.push_back(elem);
                 }
             } else {
-                prompt = "";
+                // single input prompt
+                prompts.push_back(body["input"]);
+            }
+        } else if (body.count("content") != 0) {
+            // only support single prompt here
+            std::string content = body["content"];
+            prompts.push_back(content);
+        } else {
+            // TODO @ngxson : should return an error here
+            prompts.push_back("");
         }

+        // process all prompts
+        json responses = json::array();
+        for (auto & prompt : prompts) {
+            // TODO @ngxson : maybe support multitask for this endpoint?
            // create and queue the task
            const int id_task = ctx_server.queue_tasks.get_new_id();

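The merged embeddings handler above accepts both request shapes: an OpenAI-style `input` (a single prompt or an array of prompts, answered in OpenAI format) and the legacy `content` field (answered with the raw task result). A rough sketch of the two bodies it distinguishes, using nlohmann::json with placeholder prompt strings:

#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // OpenAI-compatible shape: "input" may be a string or an array of prompts;
    // the response is wrapped by format_embeddings_response_oaicompat().
    json oai_body = {
        {"input", json::array({"first prompt", "second prompt"})}
    };

    // legacy shape: a single "content" field; the raw task result is returned as-is.
    json legacy_body = {
        {"content", "a single prompt"}
    };

    (void) oai_body;
    (void) legacy_body;
    return 0;
}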
@@ -3382,19 +3343,77 @@ int main(int argc, char ** argv) {
            // get the result
            server_task_result result = ctx_server.queue_results.recv(id_task);
            ctx_server.queue_results.remove_waiting_task_id(id_task);
+            responses.push_back(result.data);
+        }

-            json data = json::array({json{
-                {"embedding", json_value(result.data, "embedding", json::array())},
-                {"index", 0},
+        // write JSON response
+        json root;
+        if (is_openai) {
+            json res_oai = json::array();
+            int i = 0;
+            for (auto & elem : responses) {
+                res_oai.push_back(json{
+                    {"embedding", json_value(elem, "embedding", json::array())},
+                    {"index", i++},
                     {"object", "embedding"}
-            }}
-            );
-
-            json root = format_embeddings_response_oaicompat(body, data);
-
-            return res.set_content(root.dump(), "application/json; charset=utf-8");
                 });
+            }
+            root = format_embeddings_response_oaicompat(body, res_oai);
+        } else {
+            root = responses[0];
+        }
+        return res.set_content(root.dump(), "application/json; charset=utf-8");
+    };
+
+    //
+    // Router
+    //
+
+    // register static assets routes
+    if (!sparams.public_path.empty()) {
+        // Set the base directory for serving static files
+        svr->set_base_dir(sparams.public_path);
+    }
+
+    // using embedded static files
+    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+            return false;
+        };
+    };
+
+    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
+        // TODO @ngxson : I have no idea what it is... maybe this is redundant?
+        return res.set_content("", "application/json; charset=utf-8");
+    });
+    svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
+    svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
+    svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
+    svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
+        json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
+
+    // register API routes
+    svr->Get ("/health",              handle_health);
+    svr->Get ("/slots",               handle_slots);
+    svr->Get ("/metrics",             handle_metrics);
+    svr->Get ("/props",               handle_props);
+    svr->Get ("/v1/models",           handle_models);
+    svr->Post("/completion",          handle_completions); // legacy
+    svr->Post("/completions",         handle_completions);
+    svr->Post("/v1/completions",      handle_completions);
+    svr->Post("/chat/completions",    handle_chat_completions);
+    svr->Post("/v1/chat/completions", handle_chat_completions);
+    svr->Post("/infill",              handle_infill);
+    svr->Post("/embedding",           handle_embeddings); // legacy
+    svr->Post("/embeddings",          handle_embeddings);
+    svr->Post("/v1/embeddings",       handle_embeddings);
+    svr->Post("/tokenize",            handle_tokenize);
+    svr->Post("/detokenize",          handle_detokenize);

+    //
+    // Start the server
+    //
     if (sparams.n_threads_http < 1) {
         // +2 threads for monitoring endpoints
         sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
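`handle_static_file` above is a small factory: it captures a pointer to an embedded asset, its length and a MIME type, and returns the actual request handler. A minimal standalone sketch of the same pattern (the embedded asset here is made up):

#include "httplib.h"

int main() {
    httplib::Server svr;

    // embedded "asset" standing in for index_html / index_js / completion_js
    static const unsigned char fake_index_html[] = "<html><body>hello</body></html>";

    auto handle_static_file = [](const unsigned char * content, size_t len, const char * mime_type) {
        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
            res.set_content(reinterpret_cast<const char *>(content), len, mime_type);
        };
    };

    svr.Get("/", handle_static_file(fake_index_html, sizeof(fake_index_html) - 1, "text/html; charset=utf-8"));

    svr.listen("127.0.0.1", 8080);
    return 0;
}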
@@ -39,6 +39,7 @@ Feature: Security


   Scenario Outline: CORS Options
+    Given a user api key llama.cpp
     When an OPTIONS request is sent from <origin>
     Then CORS header <cors_header> is set to <cors_header_value>

@@ -582,8 +582,9 @@ async def step_detokenize(context):
 @async_run_until_complete
 async def step_options_request(context, origin):
     async with aiohttp.ClientSession() as session:
+        headers = {'Authorization': f'Bearer {context.user_api_key}', 'Origin': origin}
         async with session.options(f'{context.base_url}/v1/chat/completions',
-                                   headers={"Origin": origin}) as response:
+                                   headers=headers) as response:
             assert response.status == 200
             context.options_response = response
