mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
server : fix format_infill (#10724)
* server : fix format_infill * fix * rename * update test * use another model * update test * update test * test_invalid_input_extra_req
This commit is contained in:
parent
e52522b869
commit
ce8784bdb1
@ -3484,6 +3484,11 @@ int main(int argc, char ** argv) {
|
|||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
|
|
||||||
// validate input
|
// validate input
|
||||||
|
if (data.contains("prompt") && !data.at("prompt").is_string()) {
|
||||||
|
// prompt is optional
|
||||||
|
res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
}
|
||||||
|
|
||||||
if (!data.contains("input_prefix")) {
|
if (!data.contains("input_prefix")) {
|
||||||
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||||
}
|
}
|
||||||
@ -3493,9 +3498,11 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
|
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
|
||||||
|
// input_extra is optional
|
||||||
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
json input_extra = json_value(data, "input_extra", json::array());
|
json input_extra = json_value(data, "input_extra", json::array());
|
||||||
for (const auto & chunk : input_extra) {
|
for (const auto & chunk : input_extra) {
|
||||||
// { "text": string, "filename": string }
|
// { "text": string, "filename": string }
|
||||||
@ -3511,6 +3518,21 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
data["input_extra"] = input_extra; // default to empty array if it's not exist
|
data["input_extra"] = input_extra; // default to empty array if it's not exist
|
||||||
|
|
||||||
|
std::string prompt = json_value(data, "prompt", std::string());
|
||||||
|
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
|
||||||
|
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||||
|
data["prompt"] = format_infill(
|
||||||
|
ctx_server.ctx,
|
||||||
|
data.at("input_prefix"),
|
||||||
|
data.at("input_suffix"),
|
||||||
|
data.at("input_extra"),
|
||||||
|
ctx_server.params_base.n_batch,
|
||||||
|
ctx_server.params_base.n_predict,
|
||||||
|
ctx_server.slots[0].n_ctx, // TODO: there should be a better way
|
||||||
|
ctx_server.params_base.spm_infill,
|
||||||
|
tokenized_prompts[0]
|
||||||
|
);
|
||||||
|
|
||||||
return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
|
return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -13,28 +13,28 @@ def test_infill_without_input_extra():
|
|||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/infill", data={
|
res = server.make_request("POST", "/infill", data={
|
||||||
"prompt": "Complete this",
|
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
|
||||||
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
|
"prompt": " int n_threads = llama_",
|
||||||
"input_suffix": "}\n",
|
"input_suffix": "}\n",
|
||||||
})
|
})
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
|
assert match_regex("(Ann|small|shiny)+", res.body["content"])
|
||||||
|
|
||||||
|
|
||||||
def test_infill_with_input_extra():
|
def test_infill_with_input_extra():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/infill", data={
|
res = server.make_request("POST", "/infill", data={
|
||||||
"prompt": "Complete this",
|
|
||||||
"input_extra": [{
|
"input_extra": [{
|
||||||
"filename": "llama.h",
|
"filename": "llama.h",
|
||||||
"text": "LLAMA_API int32_t llama_n_threads();\n"
|
"text": "LLAMA_API int32_t llama_n_threads();\n"
|
||||||
}],
|
}],
|
||||||
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
|
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
|
||||||
|
"prompt": " int n_threads = llama_",
|
||||||
"input_suffix": "}\n",
|
"input_suffix": "}\n",
|
||||||
})
|
})
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
|
assert match_regex("(Dad|excited|park)+", res.body["content"])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("input_extra", [
|
@pytest.mark.parametrize("input_extra", [
|
||||||
@ -48,10 +48,30 @@ def test_invalid_input_extra_req(input_extra):
|
|||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/infill", data={
|
res = server.make_request("POST", "/infill", data={
|
||||||
"prompt": "Complete this",
|
|
||||||
"input_extra": [input_extra],
|
"input_extra": [input_extra],
|
||||||
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
|
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
|
||||||
|
"prompt": " int n_threads = llama_",
|
||||||
"input_suffix": "}\n",
|
"input_suffix": "}\n",
|
||||||
})
|
})
|
||||||
assert res.status_code == 400
|
assert res.status_code == 400
|
||||||
assert "error" in res.body
|
assert "error" in res.body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
|
||||||
|
def test_with_qwen_model():
|
||||||
|
global server
|
||||||
|
server.model_file = None
|
||||||
|
server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
|
||||||
|
server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
|
||||||
|
server.start(timeout_seconds=600)
|
||||||
|
res = server.make_request("POST", "/infill", data={
|
||||||
|
"input_extra": [{
|
||||||
|
"filename": "llama.h",
|
||||||
|
"text": "LLAMA_API int32_t llama_n_threads();\n"
|
||||||
|
}],
|
||||||
|
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
|
||||||
|
"prompt": " int n_threads = llama_",
|
||||||
|
"input_suffix": "}\n",
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n"
|
||||||
|
@ -371,3 +371,6 @@ def match_regex(regex: str, text: str) -> bool:
|
|||||||
).search(text)
|
).search(text)
|
||||||
is not None
|
is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_slow_test_allowed():
|
||||||
|
return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
|
||||||
|
Loading…
Reference in New Issue
Block a user