From ce8784bdb153ff7794dde5a50b0ebfa51baa6171 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 8 Dec 2024 23:04:29 +0100 Subject: [PATCH] server : fix format_infill (#10724) * server : fix format_infill * fix * rename * update test * use another model * update test * update test * test_invalid_input_extra_req --- examples/server/server.cpp | 22 ++++++++++++++ examples/server/tests/unit/test_infill.py | 36 ++++++++++++++++++----- examples/server/tests/utils.py | 3 ++ 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1d9c0533d..47bfd6c4a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3484,6 +3484,11 @@ int main(int argc, char ** argv) { json data = json::parse(req.body); // validate input + if (data.contains("prompt") && !data.at("prompt").is_string()) { + // prompt is optional + res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + } + if (!data.contains("input_prefix")) { res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); } @@ -3493,9 +3498,11 @@ int main(int argc, char ** argv) { } if (data.contains("input_extra") && !data.at("input_extra").is_array()) { + // input_extra is optional res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); return; } + json input_extra = json_value(data, "input_extra", json::array()); for (const auto & chunk : input_extra) { // { "text": string, "filename": string } @@ -3511,6 +3518,21 @@ int main(int argc, char ** argv) { } data["input_extra"] = input_extra; // default to empty array if it's not exist + std::string prompt = json_value(data, "prompt", std::string()); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true); + SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); + data["prompt"] = format_infill( + ctx_server.ctx, + data.at("input_prefix"), + data.at("input_suffix"), + data.at("input_extra"), + ctx_server.params_base.n_batch, + ctx_server.params_base.n_predict, + ctx_server.slots[0].n_ctx, // TODO: there should be a better way + ctx_server.params_base.spm_infill, + tokenized_prompts[0] + ); + return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res); }; diff --git a/examples/server/tests/unit/test_infill.py b/examples/server/tests/unit/test_infill.py index 6a6d40a1c..ad4b8192a 100644 --- a/examples/server/tests/unit/test_infill.py +++ b/examples/server/tests/unit/test_infill.py @@ -13,28 +13,28 @@ def test_infill_without_input_extra(): global server server.start() res = server.make_request("POST", "/infill", data={ - "prompt": "Complete this", - "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", "input_suffix": "}\n", }) assert res.status_code == 200 - assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"]) + assert match_regex("(Ann|small|shiny)+", res.body["content"]) def test_infill_with_input_extra(): global server server.start() res = server.make_request("POST", "/infill", data={ - "prompt": "Complete this", "input_extra": [{ "filename": "llama.h", "text": "LLAMA_API int32_t llama_n_threads();\n" }], - "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", "input_suffix": "}\n", }) assert res.status_code == 200 - assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"]) + assert match_regex("(Dad|excited|park)+", res.body["content"]) @pytest.mark.parametrize("input_extra", [ @@ -48,10 +48,30 @@ def test_invalid_input_extra_req(input_extra): global server server.start() res = server.make_request("POST", "/infill", data={ - "prompt": "Complete this", "input_extra": [input_extra], - "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_", + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", "input_suffix": "}\n", }) assert res.status_code == 400 assert "error" in res.body + + +@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test") +def test_with_qwen_model(): + global server + server.model_file = None + server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF" + server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf" + server.start(timeout_seconds=600) + res = server.make_request("POST", "/infill", data={ + "input_extra": [{ + "filename": "llama.h", + "text": "LLAMA_API int32_t llama_n_threads();\n" + }], + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 200 + assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n" diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 69215eaa4..7c89b9cd3 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -371,3 +371,6 @@ def match_regex(regex: str, text: str) -> bool: ).search(text) is not None ) + +def is_slow_test_allowed(): + return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"