From cd108e641dbdedd8c5641c4cec1762f751f38136 Mon Sep 17 00:00:00 2001 From: Behnam M <58621210+ibehnam@users.noreply.github.com> Date: Wed, 10 Jan 2024 14:56:05 -0500 Subject: [PATCH] server : add a `/health` endpoint (#4860) * added /health endpoint to the server * added comments on the additional /health endpoint * Better handling of server state When the model is being loaded, the server state is `LOADING_MODEL`. If model-loading fails, the server state becomes `ERROR`, otherwise it becomes `READY`. The `/health` endpoint provides more granular messages now according to the server_state value. * initialized server_state * fixed a typo * starting http server before initializing the model * Update server.cpp * Update server.cpp * fixes * fixes * fixes * made ServerState atomic and turned two-line spaces into one-line --- examples/server/server.cpp | 199 +++++++++++++++++++++---------------- 1 file changed, 113 insertions(+), 86 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6c7fcd176..1cca634d5 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #ifndef SERVER_VERBOSE #define SERVER_VERBOSE 1 @@ -146,6 +147,12 @@ static std::vector base64_decode(const std::string & encoded_string) // parallel // +enum ServerState { + LOADING_MODEL, // Server is starting up, model not fully loaded yet + READY, // Server is ready and model is loaded + ERROR // An error occurred, load_model failed +}; + enum task_type { COMPLETION_TASK, CANCEL_TASK @@ -2453,7 +2460,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } } - static std::string random_string() { static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); @@ -2790,15 +2796,117 @@ int main(int argc, char **argv) {"system_info", llama_print_system_info()}, }); - // load the model - if (!llama.load_model(params)) + httplib::Server svr; + + std::atomic server_state{LOADING_MODEL}; + + svr.set_default_headers({{"Server", "llama.cpp"}, + {"Access-Control-Allow-Origin", "*"}, + {"Access-Control-Allow-Headers", "content-type"}}); + + svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) { + ServerState current_state = server_state.load(); + switch(current_state) { + case READY: + res.set_content(R"({"status": "ok"})", "application/json"); + res.status = 200; // HTTP OK + break; + case LOADING_MODEL: + res.set_content(R"({"status": "loading model"})", "application/json"); + res.status = 503; // HTTP Service Unavailable + break; + case ERROR: + res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); + res.status = 500; // HTTP Internal Server Error + break; + } + }); + + svr.set_logger(log_server_request); + + svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) + { + const char fmt[] = "500 Internal Server Error\n%s"; + char buf[BUFSIZ]; + try + { + std::rethrow_exception(std::move(ep)); + } + catch (std::exception &e) + { + snprintf(buf, sizeof(buf), fmt, e.what()); + } + catch (...) + { + snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); + } + res.set_content(buf, "text/plain; charset=utf-8"); + res.status = 500; + }); + + svr.set_error_handler([](const httplib::Request &, httplib::Response &res) + { + if (res.status == 401) + { + res.set_content("Unauthorized", "text/plain; charset=utf-8"); + } + if (res.status == 400) + { + res.set_content("Invalid request", "text/plain; charset=utf-8"); + } + else if (res.status == 404) + { + res.set_content("File Not Found", "text/plain; charset=utf-8"); + res.status = 404; + } + }); + + // set timeouts and change hostname and port + svr.set_read_timeout (sparams.read_timeout); + svr.set_write_timeout(sparams.write_timeout); + + if (!svr.bind_to_port(sparams.hostname, sparams.port)) { + fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); return 1; } - llama.initialize(); + // Set the base directory for serving static files + svr.set_base_dir(sparams.public_path); - httplib::Server svr; + // to make it ctrl+clickable: + LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); + + std::unordered_map log_data; + log_data["hostname"] = sparams.hostname; + log_data["port"] = std::to_string(sparams.port); + + if (!sparams.api_key.empty()) { + log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4); + } + + LOG_INFO("HTTP server listening", log_data); + // run the HTTP server in a thread - see comment below + std::thread t([&]() + { + if (!svr.listen_after_bind()) + { + server_state.store(ERROR); + return 1; + } + + return 0; + }); + + // load the model + if (!llama.load_model(params)) + { + server_state.store(ERROR); + return 1; + } else { + llama.initialize(); + server_state.store(READY); + } // Middleware for API key validation auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { @@ -2826,10 +2934,6 @@ int main(int argc, char **argv) return false; }; - svr.set_default_headers({{"Server", "llama.cpp"}, - {"Access-Control-Allow-Origin", "*"}, - {"Access-Control-Allow-Headers", "content-type"}}); - // this is only called if no index.html is found in the public --path svr.Get("/", [](const httplib::Request &, httplib::Response &res) { @@ -2937,8 +3041,6 @@ int main(int argc, char **argv) } }); - - svr.Get("/v1/models", [¶ms](const httplib::Request&, httplib::Response& res) { std::time_t t = std::time(0); @@ -3157,81 +3259,6 @@ int main(int argc, char **argv) return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); }); - svr.set_logger(log_server_request); - - svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) - { - const char fmt[] = "500 Internal Server Error\n%s"; - char buf[BUFSIZ]; - try - { - std::rethrow_exception(std::move(ep)); - } - catch (std::exception &e) - { - snprintf(buf, sizeof(buf), fmt, e.what()); - } - catch (...) - { - snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); - } - res.set_content(buf, "text/plain; charset=utf-8"); - res.status = 500; - }); - - svr.set_error_handler([](const httplib::Request &, httplib::Response &res) - { - if (res.status == 401) - { - res.set_content("Unauthorized", "text/plain; charset=utf-8"); - } - if (res.status == 400) - { - res.set_content("Invalid request", "text/plain; charset=utf-8"); - } - else if (res.status == 404) - { - res.set_content("File Not Found", "text/plain; charset=utf-8"); - res.status = 404; - } - }); - - // set timeouts and change hostname and port - svr.set_read_timeout (sparams.read_timeout); - svr.set_write_timeout(sparams.write_timeout); - - if (!svr.bind_to_port(sparams.hostname, sparams.port)) - { - fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); - return 1; - } - - // Set the base directory for serving static files - svr.set_base_dir(sparams.public_path); - - // to make it ctrl+clickable: - LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); - - std::unordered_map log_data; - log_data["hostname"] = sparams.hostname; - log_data["port"] = std::to_string(sparams.port); - - if (!sparams.api_key.empty()) { - log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4); - } - - LOG_INFO("HTTP server listening", log_data); - // run the HTTP server in a thread - see comment below - std::thread t([&]() - { - if (!svr.listen_after_bind()) - { - return 1; - } - - return 0; - }); - // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? // "Bus error: 10" - this is on macOS, it does not crash on Linux //std::thread t2([&]()