mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 06:10:29 +01:00
server : (refactor) no more json in server_task input (#10691)
* server : (refactor) no more json in server_task input * add test for slots endpoint * add tests for /props and /slots * remove task inf_type * fix CI by adding safe_json_to_str * add "model_path" to /props * update readme
This commit is contained in:
parent
d9c3ba2b77
commit
3573fa8e7b
@ -687,12 +687,14 @@ This endpoint is public (no API key check). By default, it is read-only. To make
|
||||
}
|
||||
},
|
||||
"total_slots": 1,
|
||||
"model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
|
||||
"chat_template": "..."
|
||||
}
|
||||
```
|
||||
|
||||
- `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
|
||||
- `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
|
||||
- `model_path` - the path to model file (same with `-m` argument)
|
||||
- `chat_template` - the model's original Jinja2 prompt template
|
||||
|
||||
### POST `/props`: Change server global properties.
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -22,7 +22,12 @@ def test_server_props():
|
||||
server.start()
|
||||
res = server.make_request("GET", "/props")
|
||||
assert res.status_code == 200
|
||||
assert ".gguf" in res.body["model_path"]
|
||||
assert res.body["total_slots"] == server.n_slots
|
||||
default_val = res.body["default_generation_settings"]
|
||||
assert server.n_ctx is not None and server.n_slots is not None
|
||||
assert default_val["n_ctx"] == server.n_ctx / server.n_slots
|
||||
assert default_val["params"]["seed"] == server.seed
|
||||
|
||||
|
||||
def test_server_models():
|
||||
@ -33,6 +38,31 @@ def test_server_models():
|
||||
assert len(res.body["data"]) == 1
|
||||
assert res.body["data"][0]["id"] == server.model_alias
|
||||
|
||||
|
||||
def test_server_slots():
|
||||
global server
|
||||
|
||||
# without slots endpoint enabled, this should return error
|
||||
server.server_slots = False
|
||||
server.start()
|
||||
res = server.make_request("GET", "/slots")
|
||||
assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
|
||||
assert "error" in res.body
|
||||
server.stop()
|
||||
|
||||
# with slots endpoint enabled, this should return slots info
|
||||
server.server_slots = True
|
||||
server.n_slots = 2
|
||||
server.start()
|
||||
res = server.make_request("GET", "/slots")
|
||||
assert res.status_code == 200
|
||||
assert len(res.body) == server.n_slots
|
||||
assert server.n_ctx is not None and server.n_slots is not None
|
||||
assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
|
||||
assert "params" in res.body[0]
|
||||
assert res.body[0]["params"]["seed"] == server.seed
|
||||
|
||||
|
||||
def test_load_split_model():
|
||||
global server
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
|
@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
|
||||
assert res.body["model"] == model if model is not None else server.model_alias
|
||||
assert res.body["usage"]["prompt_tokens"] == n_prompt
|
||||
assert res.body["usage"]["completion_tokens"] == n_predicted
|
||||
@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
|
||||
"stream": True,
|
||||
})
|
||||
content = ""
|
||||
last_cmpl_id = None
|
||||
for data in res:
|
||||
choice = data["choices"][0]
|
||||
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
|
||||
if last_cmpl_id is None:
|
||||
last_cmpl_id = data["id"]
|
||||
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
|
||||
if choice["finish_reason"] in ["stop", "length"]:
|
||||
assert data["usage"]["prompt_tokens"] == n_prompt
|
||||
assert data["usage"]["completion_tokens"] == n_predicted
|
||||
|
@ -64,6 +64,7 @@ class ServerProcess:
|
||||
server_embeddings: bool | None = False
|
||||
server_reranking: bool | None = False
|
||||
server_metrics: bool | None = False
|
||||
server_slots: bool | None = False
|
||||
draft: int | None = None
|
||||
api_key: str | None = None
|
||||
response_format: str | None = None
|
||||
@ -91,7 +92,6 @@ class ServerProcess:
|
||||
else:
|
||||
server_path = "../../../build/bin/llama-server"
|
||||
server_args = [
|
||||
"--slots", # requires to get slot status via /slots endpoint
|
||||
"--host",
|
||||
self.server_host,
|
||||
"--port",
|
||||
@ -129,6 +129,8 @@ class ServerProcess:
|
||||
server_args.append("--reranking")
|
||||
if self.server_metrics:
|
||||
server_args.append("--metrics")
|
||||
if self.server_slots:
|
||||
server_args.append("--slots")
|
||||
if self.model_alias:
|
||||
server_args.extend(["--alias", self.model_alias])
|
||||
if self.n_ctx:
|
||||
@ -181,7 +183,7 @@ class ServerProcess:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
try:
|
||||
response = self.make_request("GET", "/slots", headers={
|
||||
response = self.make_request("GET", "/health", headers={
|
||||
"Authorization": f"Bearer {self.api_key}" if self.api_key else None
|
||||
})
|
||||
if response.status_code == 200:
|
||||
@ -224,7 +226,7 @@ class ServerProcess:
|
||||
result.headers = dict(response.headers)
|
||||
result.status_code = response.status_code
|
||||
result.body = response.json() if parse_body else None
|
||||
print("Response from server", result.body)
|
||||
print("Response from server", json.dumps(result.body, indent=2))
|
||||
return result
|
||||
|
||||
def make_stream_request(
|
||||
@ -245,7 +247,7 @@ class ServerProcess:
|
||||
break
|
||||
elif line.startswith('data: '):
|
||||
data = json.loads(line[6:])
|
||||
print("Partial response from server", data)
|
||||
print("Partial response from server", json.dumps(data, indent=2))
|
||||
yield data
|
||||
|
||||
|
||||
|
@ -164,6 +164,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
|
||||
} else {
|
||||
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
|
||||
}
|
||||
if (result.empty()) {
|
||||
throw std::runtime_error("\"prompt\" must not be empty");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse(
|
||||
const std::string & chat_template) {
|
||||
json llama_params;
|
||||
|
||||
llama_params["__oaicompat"] = true;
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
|
||||
|
||||
@ -648,3 +649,18 @@ static json format_detokenized_response(const std::string & content) {
|
||||
{"content", content}
|
||||
};
|
||||
}
|
||||
|
||||
static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
|
||||
json data = json::array();
|
||||
for (const auto & lb : logit_bias) {
|
||||
data.push_back(json{
|
||||
{"bias", lb.bias},
|
||||
{"token", lb.token},
|
||||
});
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
static std::string safe_json_to_str(json data) {
|
||||
return data.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user