mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
server : add optional API Key Authentication example (#4441)
* Add API key authentication for enhanced server-client security * server : to snake_case --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
ee4725a686
commit
88ae8952b6
@ -34,7 +34,8 @@ export async function* llama(prompt, params = {}, config = {}) {
|
|||||||
headers: {
|
headers: {
|
||||||
'Connection': 'keep-alive',
|
'Connection': 'keep-alive',
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Accept': 'text/event-stream'
|
'Accept': 'text/event-stream',
|
||||||
|
...(params.api_key ? {'Authorization': `Bearer ${params.api_key}`} : {})
|
||||||
},
|
},
|
||||||
signal: controller.signal,
|
signal: controller.signal,
|
||||||
});
|
});
|
||||||
|
@ -235,7 +235,8 @@
|
|||||||
grammar: '',
|
grammar: '',
|
||||||
n_probs: 0, // no completion_probabilities,
|
n_probs: 0, // no completion_probabilities,
|
||||||
image_data: [],
|
image_data: [],
|
||||||
cache_prompt: true
|
cache_prompt: true,
|
||||||
|
api_key: ''
|
||||||
})
|
})
|
||||||
|
|
||||||
/* START: Support for storing prompt templates and parameters in browsers LocalStorage */
|
/* START: Support for storing prompt templates and parameters in browsers LocalStorage */
|
||||||
@ -790,6 +791,10 @@
|
|||||||
<fieldset>
|
<fieldset>
|
||||||
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
|
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
|
||||||
</fieldset>
|
</fieldset>
|
||||||
|
<fieldset>
|
||||||
|
<label for="api_key">API Key</label>
|
||||||
|
<input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} />
|
||||||
|
</fieldset>
|
||||||
</details>
|
</details>
|
||||||
</form>
|
</form>
|
||||||
`
|
`
|
||||||
|
@ -36,6 +36,7 @@ using json = nlohmann::json;
|
|||||||
struct server_params
|
struct server_params
|
||||||
{
|
{
|
||||||
std::string hostname = "127.0.0.1";
|
std::string hostname = "127.0.0.1";
|
||||||
|
std::string api_key;
|
||||||
std::string public_path = "examples/server/public";
|
std::string public_path = "examples/server/public";
|
||||||
int32_t port = 8080;
|
int32_t port = 8080;
|
||||||
int32_t read_timeout = 600;
|
int32_t read_timeout = 600;
|
||||||
@ -1953,6 +1954,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||||||
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
||||||
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
||||||
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
||||||
|
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
||||||
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
||||||
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
|
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
|
||||||
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
|
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
|
||||||
@ -2002,6 +2004,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||||||
}
|
}
|
||||||
sparams.public_path = argv[i];
|
sparams.public_path = argv[i];
|
||||||
}
|
}
|
||||||
|
else if (arg == "--api-key")
|
||||||
|
{
|
||||||
|
if (++i >= argc)
|
||||||
|
{
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sparams.api_key = argv[i];
|
||||||
|
}
|
||||||
else if (arg == "--timeout" || arg == "-to")
|
else if (arg == "--timeout" || arg == "-to")
|
||||||
{
|
{
|
||||||
if (++i >= argc)
|
if (++i >= argc)
|
||||||
@ -2669,6 +2680,32 @@ int main(int argc, char **argv)
|
|||||||
|
|
||||||
httplib::Server svr;
|
httplib::Server svr;
|
||||||
|
|
||||||
|
// Middleware for API key validation
|
||||||
|
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
|
||||||
|
// If API key is not set, skip validation
|
||||||
|
if (sparams.api_key.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for API key in the header
|
||||||
|
auto auth_header = req.get_header_value("Authorization");
|
||||||
|
std::string prefix = "Bearer ";
|
||||||
|
if (auth_header.substr(0, prefix.size()) == prefix) {
|
||||||
|
std::string received_api_key = auth_header.substr(prefix.size());
|
||||||
|
if (received_api_key == sparams.api_key) {
|
||||||
|
return true; // API key is valid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// API key is invalid or not provided
|
||||||
|
res.set_content("Unauthorized: Invalid API Key", "text/plain");
|
||||||
|
res.status = 401; // Unauthorized
|
||||||
|
|
||||||
|
LOG_WARNING("Unauthorized: Invalid API Key", {});
|
||||||
|
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
svr.set_default_headers({{"Server", "llama.cpp"},
|
svr.set_default_headers({{"Server", "llama.cpp"},
|
||||||
{"Access-Control-Allow-Origin", "*"},
|
{"Access-Control-Allow-Origin", "*"},
|
||||||
{"Access-Control-Allow-Headers", "content-type"}});
|
{"Access-Control-Allow-Headers", "content-type"}});
|
||||||
@ -2711,8 +2748,11 @@ int main(int argc, char **argv)
|
|||||||
res.set_content(data.dump(), "application/json");
|
res.set_content(data.dump(), "application/json");
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
|
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
|
||||||
{
|
{
|
||||||
|
if (!validate_api_key(req, res)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
const int task_id = llama.request_completion(data, false, false, -1);
|
const int task_id = llama.request_completion(data, false, false, -1);
|
||||||
if (!json_value(data, "stream", false)) {
|
if (!json_value(data, "stream", false)) {
|
||||||
@ -2799,8 +2839,11 @@ int main(int argc, char **argv)
|
|||||||
});
|
});
|
||||||
|
|
||||||
// TODO: add mount point without "/v1" prefix -- how?
|
// TODO: add mount point without "/v1" prefix -- how?
|
||||||
svr.Post("/v1/chat/completions", [&llama](const httplib::Request &req, httplib::Response &res)
|
svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
|
||||||
{
|
{
|
||||||
|
if (!validate_api_key(req, res)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
||||||
|
|
||||||
const int task_id = llama.request_completion(data, false, false, -1);
|
const int task_id = llama.request_completion(data, false, false, -1);
|
||||||
@ -2869,8 +2912,11 @@ int main(int argc, char **argv)
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
|
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
|
||||||
{
|
{
|
||||||
|
if (!validate_api_key(req, res)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
const int task_id = llama.request_completion(data, true, false, -1);
|
const int task_id = llama.request_completion(data, true, false, -1);
|
||||||
if (!json_value(data, "stream", false)) {
|
if (!json_value(data, "stream", false)) {
|
||||||
@ -3005,11 +3051,15 @@ int main(int argc, char **argv)
|
|||||||
|
|
||||||
svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
|
svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
|
||||||
{
|
{
|
||||||
|
if (res.status == 401)
|
||||||
|
{
|
||||||
|
res.set_content("Unauthorized", "text/plain");
|
||||||
|
}
|
||||||
if (res.status == 400)
|
if (res.status == 400)
|
||||||
{
|
{
|
||||||
res.set_content("Invalid request", "text/plain");
|
res.set_content("Invalid request", "text/plain");
|
||||||
}
|
}
|
||||||
else if (res.status != 500)
|
else if (res.status == 404)
|
||||||
{
|
{
|
||||||
res.set_content("File Not Found", "text/plain");
|
res.set_content("File Not Found", "text/plain");
|
||||||
res.status = 404;
|
res.status = 404;
|
||||||
@ -3032,11 +3082,15 @@ int main(int argc, char **argv)
|
|||||||
// to make it ctrl+clickable:
|
// to make it ctrl+clickable:
|
||||||
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
|
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
|
||||||
|
|
||||||
LOG_INFO("HTTP server listening", {
|
std::unordered_map<std::string, std::string> log_data;
|
||||||
{"hostname", sparams.hostname},
|
log_data["hostname"] = sparams.hostname;
|
||||||
{"port", sparams.port},
|
log_data["port"] = std::to_string(sparams.port);
|
||||||
});
|
|
||||||
|
|
||||||
|
if (!sparams.api_key.empty()) {
|
||||||
|
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("HTTP server listening", log_data);
|
||||||
// run the HTTP server in a thread - see comment below
|
// run the HTTP server in a thread - see comment below
|
||||||
std::thread t([&]()
|
std::thread t([&]()
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user