From 3573fa8e7b7f0865638b52b4e9b4d2006f0558a2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 7 Dec 2024 20:21:09 +0100 Subject: [PATCH] server : (refactor) no more json in server_task input (#10691) * server : (refactor) no more json in server_task input * add test for slots endpoint * add tests for /props and /slots * remove task inf_type * fix CI by adding safe_json_to_str * add "model_path" to /props * update readme --- examples/server/README.md | 2 + examples/server/server.cpp | 744 +++++++++--------- examples/server/tests/unit/test_basic.py | 30 + .../server/tests/unit/test_chat_completion.py | 5 + examples/server/tests/utils.py | 10 +- examples/server/utils.hpp | 20 +- 6 files changed, 427 insertions(+), 384 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index 0bab40a82..117c52d3f 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -687,12 +687,14 @@ This endpoint is public (no API key check). By default, it is read-only. To make } }, "total_slots": 1, + "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", "chat_template": "..." } ``` - `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint. - `total_slots` - the total number of slots for process requests (defined by `--parallel` option) +- `model_path` - the path to model file (same with `-m` argument) - `chat_template` - the model's original Jinja2 prompt template ### POST `/props`: Change server global properties. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1ce8fbae2..1c21e55aa 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -54,7 +54,10 @@ enum server_state { }; enum server_task_type { - SERVER_TASK_TYPE_INFERENCE, + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_NEXT_RESPONSE, SERVER_TASK_TYPE_METRICS, @@ -64,13 +67,6 @@ enum server_task_type { SERVER_TASK_TYPE_SET_LORA, }; -enum server_task_inf_type { - SERVER_TASK_INF_TYPE_COMPLETION, - SERVER_TASK_INF_TYPE_EMBEDDING, - SERVER_TASK_INF_TYPE_RERANK, - SERVER_TASK_INF_TYPE_INFILL, -}; - // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 enum error_type { ERROR_TYPE_INVALID_REQUEST, @@ -82,28 +78,6 @@ enum error_type { ERROR_TYPE_NOT_SUPPORTED, // custom error }; -struct server_task { - int id = -1; // to be filled by server_queue - int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL - - llama_tokens prompt_tokens; - server_task_type type; - - // TODO @ngxson : we should get rid of json type here - json data; - - server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION; - - // utility function - static std::unordered_set get_list_id(const std::vector & tasks) { - std::unordered_set ids(tasks.size()); - for (size_t i = 0; i < tasks.size(); i++) { - ids.insert(tasks[i].id); - } - return ids; - } -}; - struct slot_params { bool stream = true; bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt @@ -118,6 +92,7 @@ struct slot_params { std::vector antiprompt; bool timings_per_token = false; + bool ignore_eos = false; struct common_params_sampling sampling; struct common_params_speculative speculative; @@ -167,7 +142,7 @@ struct slot_params { {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, {"stream", stream}, - //{"logit_bias", sampling.logit_bias}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, {"n_probs", sampling.n_probs}, {"min_keep", sampling.min_keep}, {"grammar", sampling.grammar}, @@ -180,6 +155,209 @@ struct slot_params { } }; +struct server_task { + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) + + server_task_type type; + + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + slot_params params; + llama_tokens prompt_tokens; + int id_selected_slot = -1; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + server_task(server_task_type type) : type(type) {} + + static slot_params params_from_json_cmpl( + const llama_model * model, + const common_params & params_base, + const json & data) { + slot_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) + slot_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, "timings_per_token", false); + + params.stream = json_value(data, "stream", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + params.speculative.n_min = std::max(params.speculative.n_min, 2); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { + throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); + } + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + params.sampling.grammar = json_schema_to_grammar(schema); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + } + + { + params.sampling.logit_bias.clear(); + params.ignore_eos = json_value(data, "ignore_eos", false); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_n_vocab(model); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(model, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + } + + { + const auto & samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + std::vector sampler_names; + for (const auto & name : *samplers) { + if (name.is_string()) { + sampler_names.emplace_back(name); + } + } + params.sampling.samplers = common_sampler_types_from_names(sampler_names, false); + } else if (samplers->is_string()){ + std::string sampler_string; + for (const auto & name : *samplers) { + sampler_string += name; + } + params.sampling.samplers = common_sampler_types_from_chars(sampler_string); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; + } + + // utility function + static std::unordered_set get_list_id(const std::vector & tasks) { + std::unordered_set ids(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + ids.insert(tasks[i].id); + } + return ids; + } +}; + struct result_timings { int32_t prompt_n = -1; double prompt_ms; @@ -191,7 +369,7 @@ struct result_timings { double predicted_per_token_ms; double predicted_per_second; - json to_json() { + json to_json() const { return { {"prompt_n", prompt_n}, {"prompt_ms", prompt_ms}, @@ -623,7 +801,8 @@ struct server_task_result_metrics : server_task_result { uint64_t n_decode_total = 0; uint64_t n_busy_slots_total = 0; - // TODO: get rid of this json object and use to_json() instead + // while we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result json slots_data = json::array(); virtual json to_json() override { @@ -708,6 +887,9 @@ struct server_slot { int id; int id_task = -1; + // only used for completion/embedding/infill/rerank + server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; + llama_batch batch_spec = {}; llama_context * ctx = nullptr; @@ -746,8 +928,6 @@ struct server_slot { llama_tokens cache_tokens; std::vector generated_token_probs; - server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION; - bool has_next_token = true; bool has_new_line = false; bool truncated = false; @@ -787,11 +967,15 @@ struct server_slot { n_past = 0; n_sent_text = 0; n_sent_token_probs = 0; - inf_type = SERVER_TASK_INF_TYPE_COMPLETION; + task_type = SERVER_TASK_TYPE_COMPLETION; generated_token_probs.clear(); } + bool is_non_causal() const { + return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; + } + bool has_budget(const common_params & global_params) { if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless @@ -903,6 +1087,7 @@ struct server_slot { {"n_ctx", n_ctx}, {"speculative", can_speculate()}, {"is_processing", is_processing()}, + {"non_causal", is_non_causal()}, {"params", params.to_json()}, {"prompt", common_detokenize(ctx, prompt_tokens)}, {"next_token", @@ -988,9 +1173,7 @@ struct server_queue { // Add a new task to the end of the queue int post(server_task task, bool front = false) { std::unique_lock lock(mutex_tasks); - if (task.id == -1) { - task.id = id++; - } + GGML_ASSERT(task.id != -1); QUE_DBG("new task, id = %d, front = %d\n", task.id, front); if (front) { queue_tasks.push_front(std::move(task)); @@ -1468,104 +1651,14 @@ struct server_context { } bool launch_slot_with_task(server_slot & slot, const server_task & task) { - // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - slot_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; + slot.reset(); + slot.id_task = task.id; + slot.index = task.index; + slot.task_type = task.type; + slot.params = std::move(task.params); + slot.prompt_tokens = std::move(task.prompt_tokens); - const auto & data = task.data; - - if (data.count("__oaicompat") != 0) { - std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; - slot.params.oaicompat = true; - slot.params.oaicompat_chat = json_value(data, "__oaicompat_chat", false); - slot.params.oaicompat_model = json_value(data, "model", model_name); - slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string()); - } else { - slot.params.oaicompat = false; - } - - - // enabling this will output extra debug information in the HTTP responses from the server - slot.params.verbose = params_base.verbosity > 9; - slot.params.timings_per_token = json_value(data, "timings_per_token", false); - - slot.params.stream = json_value(data, "stream", false); - slot.params.cache_prompt = json_value(data, "cache_prompt", true); - slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); - slot.params.n_indent = json_value(data, "n_indent", defaults.n_indent); - slot.params.n_keep = json_value(data, "n_keep", defaults.n_keep); - slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard); - //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement - slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - - slot.params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - slot.params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - slot.params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - slot.params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - slot.params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - slot.params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - slot.params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - slot.params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - slot.params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - slot.params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - slot.params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - slot.params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - slot.params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - slot.params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - slot.params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - slot.params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - slot.params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - slot.params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - slot.params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl); - slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); - - slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); - slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); - slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); - - slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); - slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2); - slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0); - - if (slot.params.sampling.dry_base < 1.0f) { - slot.params.sampling.dry_base = defaults.sampling.dry_base; - } - - // sequence breakers for DRY - { - // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format - // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 - - if (data.contains("dry_sequence_breakers")) { - slot.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); - if (slot.params.sampling.dry_sequence_breakers.empty()) { - send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST); - return false; - } - } - } - - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { - send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST); - return false; - } - if (data.contains("json_schema") && !data.contains("grammar")) { - try { - auto schema = json_value(data, "json_schema", json::object()); - slot.params.sampling.grammar = json_schema_to_grammar(schema); - } catch (const std::exception & e) { - send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST); - return false; - } - } else { - slot.params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); - } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? @@ -1573,78 +1666,8 @@ struct server_context { SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); } - { - slot.params.sampling.logit_bias.clear(); - - if (json_value(data, "ignore_eos", false) && has_eos_token) { - slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY}); - } - - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_n_vocab(model); - for (const auto & el : *logit_bias) { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - slot.params.sampling.logit_bias.push_back({tok, bias}); - } - } else if (el[0].is_string()) { - auto toks = common_tokenize(model, el[0].get(), false); - for (auto tok : toks) { - slot.params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - } - } - - { - slot.params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - slot.params.antiprompt.push_back(word); - } - } - } - } - - { - const auto & samplers = data.find("samplers"); - if (samplers != data.end()) { - if (samplers->is_array()) { - std::vector sampler_names; - for (const auto & name : *samplers) { - if (name.is_string()) { - sampler_names.emplace_back(name); - } - } - slot.params.sampling.samplers = common_sampler_types_from_names(sampler_names, false); - } else if (samplers->is_string()){ - std::string sampler_string; - for (const auto & name : *samplers) { - sampler_string += name; - } - slot.params.sampling.samplers = common_sampler_types_from_chars(sampler_string); - } - } else { - slot.params.sampling.samplers = defaults.sampling.samplers; - } + if (slot.params.ignore_eos && has_eos_token) { + slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY}); } { @@ -2020,82 +2043,13 @@ struct server_context { // Functions to create new task(s) and receive result(s) // - // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s) - std::vector create_tasks_inference(json data, server_task_inf_type inf_type) { - std::vector tasks; - auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) { - SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size()); - - server_task task; - task.id = queue_tasks.get_new_id(); - task.inf_type = inf_type; - task.type = SERVER_TASK_TYPE_INFERENCE; - task.data = task_data; - task.prompt_tokens = std::move(prompt_tokens); - tasks.push_back(std::move(task)); - }; - - static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts"; - if (!data.contains("prompt")) { - throw std::runtime_error(error_msg); - } - - // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread - bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL; - std::vector tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true); - switch (inf_type) { - case SERVER_TASK_INF_TYPE_RERANK: - { - // prompts[0] is the question - // the rest are the answers/documents - GGML_ASSERT(tokenized_prompts.size() > 1); - SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1); - for (size_t i = 1; i < tokenized_prompts.size(); i++) { - data["index"] = i - 1; - auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]); - create_task(data, tokens); - } - } break; - case SERVER_TASK_INF_TYPE_INFILL: - { - SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - data["index"] = i; - auto tokens = format_infill( - ctx, - data.at("input_prefix"), - data.at("input_suffix"), - data.at("input_extra"), - params_base.n_batch, - params_base.n_predict, - slots[0].n_ctx, // TODO: there should be a better way - params_base.spm_infill, - tokenized_prompts[i] - ); - create_task(data, tokens); - } - } break; - default: - { - SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - data["index"] = i; - create_task(data, tokenized_prompts[i]); - } - } - } - - return tasks; - } - void cancel_tasks(const std::unordered_set & id_tasks) { std::vector cancel_tasks; cancel_tasks.reserve(id_tasks.size()); for (const auto & id_task : id_tasks) { SRV_WRN("cancel task, id_task = %d\n", id_task); - server_task task; - task.type = SERVER_TASK_TYPE_CANCEL; + server_task task(SERVER_TASK_TYPE_CANCEL); task.id_target = id_task; cancel_tasks.push_back(task); queue_results.remove_waiting_task_id(id_task); @@ -2104,7 +2058,7 @@ struct server_context { queue_tasks.post(cancel_tasks, true); } - // receive the results from task(s) created by create_tasks_inference + // receive the results from task(s) void receive_multi_results( const std::unordered_set & id_tasks, const std::function&)> & result_handler, @@ -2131,7 +2085,7 @@ struct server_context { result_handler(results); } - // receive the results from task(s) created by create_tasks_inference, in stream mode + // receive the results from task(s), in stream mode void receive_cmpl_results_stream( const std::unordered_set & id_tasks, const std::function & result_handler, @@ -2166,9 +2120,12 @@ struct server_context { void process_single_task(server_task task) { switch (task.type) { - case SERVER_TASK_TYPE_INFERENCE: + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: { - const int id_slot = json_value(task.data, "id_slot", -1); + const int id_slot = task.id_selected_slot; server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); @@ -2185,13 +2142,6 @@ struct server_context { break; } - slot->reset(); - - slot->id_task = task.id; - slot->inf_type = task.inf_type; - slot->index = json_value(task.data, "index", 0); - slot->prompt_tokens = std::move(task.prompt_tokens); - if (!launch_slot_with_task(*slot, task)) { SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); break; @@ -2255,14 +2205,14 @@ struct server_context { res->n_decode_total = metrics.n_decode_total; res->n_busy_slots_total = metrics.n_busy_slots_total; - if (json_value(task.data, "reset_bucket", false)) { + if (task.metrics_reset_bucket) { metrics.reset_bucket(); } queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_SAVE: { - int id_slot = task.data.at("id_slot"); + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); @@ -2278,8 +2228,8 @@ struct server_context { const size_t token_count = slot->cache_tokens.size(); const int64_t t_start = ggml_time_us(); - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); @@ -2298,7 +2248,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { - int id_slot = task.data.at("id_slot"); + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); @@ -2313,8 +2263,8 @@ struct server_context { const int64_t t_start = ggml_time_us(); - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; slot->cache_tokens.resize(slot->n_ctx); size_t token_count = 0; @@ -2341,7 +2291,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { - int id_slot = task.data.at("id_slot"); + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); @@ -2400,10 +2350,8 @@ struct server_context { { SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - server_task task; - task.type = SERVER_TASK_TYPE_NEXT_RESPONSE; - task.id_target = -1; - + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); queue_tasks.post(task); } @@ -2517,7 +2465,7 @@ struct server_context { continue; } - if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) { + if (slot.is_non_causal()) { if (slot.n_prompt_tokens > n_ubatch) { slot.release(); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); @@ -2632,7 +2580,7 @@ struct server_context { } // non-causal tasks require to fit the entire prompt in the physical batch - if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) { + if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; @@ -2640,10 +2588,7 @@ struct server_context { } // check that we are in the right batch_type, if not defer the slot - const bool slot_type = - slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || - slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0; - + int slot_type = slot.is_non_causal(); if (batch_type == -1) { batch_type = slot_type; } else if (batch_type != slot_type) { @@ -2760,7 +2705,7 @@ struct server_context { } if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) { + if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { // prompt evaluated for embedding send_embedding(slot, batch_view); slot.release(); @@ -2768,7 +2713,7 @@ struct server_context { continue; // continue loop of slots } - if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) { + if (slot.task_type == SERVER_TASK_TYPE_RERANK) { send_rerank(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -2998,12 +2943,12 @@ int main(int argc, char ** argv) { auto res_error = [](httplib::Response & res, const json & error_data) { json final_response {{"error", error_data}}; - res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); res.status = json_value(error_data, "code", 500); }; auto res_ok = [](httplib::Response & res, const json & data) { - res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.set_content(safe_json_to_str(data), MIMETYPE_JSON); res.status = 200; }; @@ -3140,10 +3085,8 @@ int main(int argc, char ** argv) { } // request slots data using task queue - server_task task; + server_task task(SERVER_TASK_TYPE_METRICS); task.id = ctx_server.queue_tasks.get_new_id(); - task.type = SERVER_TASK_TYPE_METRICS; - ctx_server.queue_results.add_waiting_task_id(task.id); ctx_server.queue_tasks.post(task, true); // high-priority task @@ -3178,11 +3121,9 @@ int main(int argc, char ** argv) { } // request slots data using task queue - server_task task; + server_task task(SERVER_TASK_TYPE_METRICS); task.id = ctx_server.queue_tasks.get_new_id(); - task.id_target = -1; - task.type = SERVER_TASK_TYPE_METRICS; - task.data.push_back({{"reset_bucket", true}}); + task.metrics_reset_bucket = true; ctx_server.queue_results.add_waiting_task_id(task.id); ctx_server.queue_tasks.post(task, true); // high-priority task @@ -3286,19 +3227,17 @@ int main(int argc, char ** argv) { } std::string filepath = params.slot_save_path + filename; - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_SAVE; - task.data = { - { "id_slot", id_slot }, - { "filename", filename }, - { "filepath", filepath }, - }; + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = ctx_server.queue_tasks.get_new_id(); + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result_ptr result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); if (result->is_error()) { res_error(res, result->to_json()); @@ -3317,19 +3256,17 @@ int main(int argc, char ** argv) { } std::string filepath = params.slot_save_path + filename; - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_RESTORE; - task.data = { - { "id_slot", id_slot }, - { "filename", filename }, - { "filepath", filepath }, - }; + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = ctx_server.queue_tasks.get_new_id(); + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result_ptr result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); if (result->is_error()) { res_error(res, result->to_json()); @@ -3341,17 +3278,15 @@ int main(int argc, char ** argv) { }; const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_ERASE; - task.data = { - { "id_slot", id_slot }, - }; + server_task task(SERVER_TASK_TYPE_SLOT_ERASE); + task.id = ctx_server.queue_tasks.get_new_id(); + task.slot_action.slot_id = id_slot; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result_ptr result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); if (result->is_error()) { res_error(res, result->to_json()); @@ -3392,9 +3327,11 @@ int main(int argc, char ** argv) { }; const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + // this endpoint is publicly available, please only return what is safe to be exposed json data = { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, + { "model_path", ctx_server.params_base.model }, { "chat_template", llama_get_chat_template(ctx_server.model) }, }; @@ -3417,17 +3354,47 @@ int main(int argc, char ** argv) { // handle completion-like requests (completion, chat, infill) // we can optionally provide a custom format for partial results and final results const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok]( - server_task_inf_type inf_type, + server_task_type type, json & data, httplib::Response & res, - bool oai_compat = false) { + bool oaicompat = false, + bool oaicompat_chat = false) { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } - data["completion_id"] = gen_chatcmplid(); - std::vector tasks = ctx_server.create_tasks_inference(data, inf_type); + auto completion_id = gen_chatcmplid(); + std::vector tasks; + + try { + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true); + tasks.reserve(tokenized_prompts.size()); + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = std::move(tokenized_prompts[i]); + task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.oaicompat_chat = oaicompat_chat; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(task); + } + } catch (const std::exception & e) { + res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } + ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3453,7 +3420,7 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_ids(task_ids); } else { - const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) { + const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) { ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool { json res_json = result->to_json(); if (res_json.is_array()) { @@ -3469,7 +3436,7 @@ int main(int argc, char ** argv) { }, [&](const json & error_data) { server_sent_event(sink, "error", error_data); }); - if (oai_compat) { + if (oaicompat) { static const std::string ev_done = "data: [DONE]\n\n"; sink.write(ev_done.data(), ev_done.size()); } @@ -3487,7 +3454,12 @@ int main(int argc, char ** argv) { const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); - return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res); + return handle_completions_generic( + SERVER_TASK_TYPE_COMPLETION, + data, + res, + /* oaicompat */ false, + /* oaicompat_chat */ false); }; const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { @@ -3537,7 +3509,7 @@ int main(int argc, char ** argv) { } data["input_extra"] = input_extra; // default to empty array if it's not exist - return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res); + return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res); }; const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { @@ -3547,11 +3519,15 @@ int main(int argc, char ** argv) { } json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - data["__oaicompat_chat"] = true; - return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true); + return handle_completions_generic( + SERVER_TASK_TYPE_COMPLETION, + data, + res, + /* oaicompat */ true, + /* oaicompat_chat */ true); }; - const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { + const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { json models = { {"object", "list"}, {"data", { @@ -3565,7 +3541,7 @@ int main(int argc, char ** argv) { }} }; - res.set_content(models.dump(), MIMETYPE_JSON); + res_ok(res, models); }; const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { @@ -3642,7 +3618,16 @@ int main(int argc, char ** argv) { json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING); + std::vector tasks; + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true); + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = std::move(tokenized_prompts[i]); + tasks.push_back(task); + } + ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3669,7 +3654,7 @@ int main(int argc, char ** argv) { // write JSON response json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) - : responses[0]; + : responses.size() == 1 ? responses[0] : json(responses); res_ok(res, root); }; @@ -3708,20 +3693,23 @@ int main(int argc, char ** argv) { return; } - // construct prompt object: array of ["query", "doc0", "doc1", ...] - json prompt; - prompt.push_back(query); - for (const auto & doc : documents) { - prompt.push_back(doc); - } - - LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str()); + llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0]; // create and queue the task json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK); + std::vector tasks; + std::vector tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true); + tasks.reserve(tokenized_docs.size()); + for (size_t i = 0; i < tokenized_docs.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]); + tasks.push_back(task); + } + ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3782,13 +3770,13 @@ int main(int argc, char ** argv) { } } - server_task task; - task.type = SERVER_TASK_TYPE_SET_LORA; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + server_task task(SERVER_TASK_TYPE_SET_LORA); + task.id = ctx_server.queue_tasks.get_new_id(); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result_ptr result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); if (result->is_error()) { res_error(res, result->to_json()); diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py index d82d54a5a..1d5124016 100644 --- a/examples/server/tests/unit/test_basic.py +++ b/examples/server/tests/unit/test_basic.py @@ -22,7 +22,12 @@ def test_server_props(): server.start() res = server.make_request("GET", "/props") assert res.status_code == 200 + assert ".gguf" in res.body["model_path"] assert res.body["total_slots"] == server.n_slots + default_val = res.body["default_generation_settings"] + assert server.n_ctx is not None and server.n_slots is not None + assert default_val["n_ctx"] == server.n_ctx / server.n_slots + assert default_val["params"]["seed"] == server.seed def test_server_models(): @@ -33,6 +38,31 @@ def test_server_models(): assert len(res.body["data"]) == 1 assert res.body["data"][0]["id"] == server.model_alias + +def test_server_slots(): + global server + + # without slots endpoint enabled, this should return error + server.server_slots = False + server.start() + res = server.make_request("GET", "/slots") + assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED + assert "error" in res.body + server.stop() + + # with slots endpoint enabled, this should return slots info + server.server_slots = True + server.n_slots = 2 + server.start() + res = server.make_request("GET", "/slots") + assert res.status_code == 200 + assert len(res.body) == server.n_slots + assert server.n_ctx is not None and server.n_slots is not None + assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots + assert "params" in res.body[0] + assert res.body[0]["params"]["seed"] == server.seed + + def test_load_split_model(): global server server.model_hf_repo = "ggml-org/models" diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index f13c6c4ca..6573cc17f 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte ], }) assert res.status_code == 200 + assert "cmpl" in res.body["id"] # make sure the completion id has the expected format assert res.body["model"] == model if model is not None else server.model_alias assert res.body["usage"]["prompt_tokens"] == n_prompt assert res.body["usage"]["completion_tokens"] == n_predicted @@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte "stream": True, }) content = "" + last_cmpl_id = None for data in res: choice = data["choices"][0] assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future + if last_cmpl_id is None: + last_cmpl_id = data["id"] + assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream if choice["finish_reason"] in ["stop", "length"]: assert data["usage"]["prompt_tokens"] == n_prompt assert data["usage"]["completion_tokens"] == n_predicted diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index e17a05ff6..69215eaa4 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -64,6 +64,7 @@ class ServerProcess: server_embeddings: bool | None = False server_reranking: bool | None = False server_metrics: bool | None = False + server_slots: bool | None = False draft: int | None = None api_key: str | None = None response_format: str | None = None @@ -91,7 +92,6 @@ class ServerProcess: else: server_path = "../../../build/bin/llama-server" server_args = [ - "--slots", # requires to get slot status via /slots endpoint "--host", self.server_host, "--port", @@ -129,6 +129,8 @@ class ServerProcess: server_args.append("--reranking") if self.server_metrics: server_args.append("--metrics") + if self.server_slots: + server_args.append("--slots") if self.model_alias: server_args.extend(["--alias", self.model_alias]) if self.n_ctx: @@ -181,7 +183,7 @@ class ServerProcess: start_time = time.time() while time.time() - start_time < timeout_seconds: try: - response = self.make_request("GET", "/slots", headers={ + response = self.make_request("GET", "/health", headers={ "Authorization": f"Bearer {self.api_key}" if self.api_key else None }) if response.status_code == 200: @@ -224,7 +226,7 @@ class ServerProcess: result.headers = dict(response.headers) result.status_code = response.status_code result.body = response.json() if parse_body else None - print("Response from server", result.body) + print("Response from server", json.dumps(result.body, indent=2)) return result def make_stream_request( @@ -245,7 +247,7 @@ class ServerProcess: break elif line.startswith('data: '): data = json.loads(line[6:]) - print("Partial response from server", data) + print("Partial response from server", json.dumps(data, indent=2)) yield data diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index c9fe7d966..8f545aea5 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -164,6 +164,9 @@ static std::vector tokenize_input_prompts(llama_context * ctx, con } else { throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); } + if (result.empty()) { + throw std::runtime_error("\"prompt\" must not be empty"); + } return result; } @@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse( const std::string & chat_template) { json llama_params; - llama_params["__oaicompat"] = true; - // Apply chat template to the list of messages llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); @@ -648,3 +649,18 @@ static json format_detokenized_response(const std::string & content) { {"content", content} }; } + +static json format_logit_bias(const std::vector & logit_bias) { + json data = json::array(); + for (const auto & lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); + } + return data; +} + +static std::string safe_json_to_str(json data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); +}