mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
server : (refactor) no more json in server_task input (#10691)
* server : (refactor) no more json in server_task input * add test for slots endpoint * add tests for /props and /slots * remove task inf_type * fix CI by adding safe_json_to_str * add "model_path" to /props * update readme
This commit is contained in:
parent
d9c3ba2b77
commit
3573fa8e7b
@ -687,12 +687,14 @@ This endpoint is public (no API key check). By default, it is read-only. To make
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"total_slots": 1,
|
"total_slots": 1,
|
||||||
|
"model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
|
||||||
"chat_template": "..."
|
"chat_template": "..."
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
- `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
|
- `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
|
||||||
- `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
|
- `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
|
||||||
|
- `model_path` - the path to model file (same with `-m` argument)
|
||||||
- `chat_template` - the model's original Jinja2 prompt template
|
- `chat_template` - the model's original Jinja2 prompt template
|
||||||
|
|
||||||
### POST `/props`: Change server global properties.
|
### POST `/props`: Change server global properties.
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -22,7 +22,12 @@ def test_server_props():
|
|||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("GET", "/props")
|
res = server.make_request("GET", "/props")
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
|
assert ".gguf" in res.body["model_path"]
|
||||||
assert res.body["total_slots"] == server.n_slots
|
assert res.body["total_slots"] == server.n_slots
|
||||||
|
default_val = res.body["default_generation_settings"]
|
||||||
|
assert server.n_ctx is not None and server.n_slots is not None
|
||||||
|
assert default_val["n_ctx"] == server.n_ctx / server.n_slots
|
||||||
|
assert default_val["params"]["seed"] == server.seed
|
||||||
|
|
||||||
|
|
||||||
def test_server_models():
|
def test_server_models():
|
||||||
@ -33,6 +38,31 @@ def test_server_models():
|
|||||||
assert len(res.body["data"]) == 1
|
assert len(res.body["data"]) == 1
|
||||||
assert res.body["data"][0]["id"] == server.model_alias
|
assert res.body["data"][0]["id"] == server.model_alias
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_slots():
|
||||||
|
global server
|
||||||
|
|
||||||
|
# without slots endpoint enabled, this should return error
|
||||||
|
server.server_slots = False
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("GET", "/slots")
|
||||||
|
assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
|
||||||
|
assert "error" in res.body
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
# with slots endpoint enabled, this should return slots info
|
||||||
|
server.server_slots = True
|
||||||
|
server.n_slots = 2
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("GET", "/slots")
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert len(res.body) == server.n_slots
|
||||||
|
assert server.n_ctx is not None and server.n_slots is not None
|
||||||
|
assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
|
||||||
|
assert "params" in res.body[0]
|
||||||
|
assert res.body[0]["params"]["seed"] == server.seed
|
||||||
|
|
||||||
|
|
||||||
def test_load_split_model():
|
def test_load_split_model():
|
||||||
global server
|
global server
|
||||||
server.model_hf_repo = "ggml-org/models"
|
server.model_hf_repo = "ggml-org/models"
|
||||||
|
@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
|
|||||||
],
|
],
|
||||||
})
|
})
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
|
assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
|
||||||
assert res.body["model"] == model if model is not None else server.model_alias
|
assert res.body["model"] == model if model is not None else server.model_alias
|
||||||
assert res.body["usage"]["prompt_tokens"] == n_prompt
|
assert res.body["usage"]["prompt_tokens"] == n_prompt
|
||||||
assert res.body["usage"]["completion_tokens"] == n_predicted
|
assert res.body["usage"]["completion_tokens"] == n_predicted
|
||||||
@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
|
|||||||
"stream": True,
|
"stream": True,
|
||||||
})
|
})
|
||||||
content = ""
|
content = ""
|
||||||
|
last_cmpl_id = None
|
||||||
for data in res:
|
for data in res:
|
||||||
choice = data["choices"][0]
|
choice = data["choices"][0]
|
||||||
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
|
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
|
||||||
|
if last_cmpl_id is None:
|
||||||
|
last_cmpl_id = data["id"]
|
||||||
|
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
|
||||||
if choice["finish_reason"] in ["stop", "length"]:
|
if choice["finish_reason"] in ["stop", "length"]:
|
||||||
assert data["usage"]["prompt_tokens"] == n_prompt
|
assert data["usage"]["prompt_tokens"] == n_prompt
|
||||||
assert data["usage"]["completion_tokens"] == n_predicted
|
assert data["usage"]["completion_tokens"] == n_predicted
|
||||||
|
@ -64,6 +64,7 @@ class ServerProcess:
|
|||||||
server_embeddings: bool | None = False
|
server_embeddings: bool | None = False
|
||||||
server_reranking: bool | None = False
|
server_reranking: bool | None = False
|
||||||
server_metrics: bool | None = False
|
server_metrics: bool | None = False
|
||||||
|
server_slots: bool | None = False
|
||||||
draft: int | None = None
|
draft: int | None = None
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
response_format: str | None = None
|
response_format: str | None = None
|
||||||
@ -91,7 +92,6 @@ class ServerProcess:
|
|||||||
else:
|
else:
|
||||||
server_path = "../../../build/bin/llama-server"
|
server_path = "../../../build/bin/llama-server"
|
||||||
server_args = [
|
server_args = [
|
||||||
"--slots", # requires to get slot status via /slots endpoint
|
|
||||||
"--host",
|
"--host",
|
||||||
self.server_host,
|
self.server_host,
|
||||||
"--port",
|
"--port",
|
||||||
@ -129,6 +129,8 @@ class ServerProcess:
|
|||||||
server_args.append("--reranking")
|
server_args.append("--reranking")
|
||||||
if self.server_metrics:
|
if self.server_metrics:
|
||||||
server_args.append("--metrics")
|
server_args.append("--metrics")
|
||||||
|
if self.server_slots:
|
||||||
|
server_args.append("--slots")
|
||||||
if self.model_alias:
|
if self.model_alias:
|
||||||
server_args.extend(["--alias", self.model_alias])
|
server_args.extend(["--alias", self.model_alias])
|
||||||
if self.n_ctx:
|
if self.n_ctx:
|
||||||
@ -181,7 +183,7 @@ class ServerProcess:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
while time.time() - start_time < timeout_seconds:
|
while time.time() - start_time < timeout_seconds:
|
||||||
try:
|
try:
|
||||||
response = self.make_request("GET", "/slots", headers={
|
response = self.make_request("GET", "/health", headers={
|
||||||
"Authorization": f"Bearer {self.api_key}" if self.api_key else None
|
"Authorization": f"Bearer {self.api_key}" if self.api_key else None
|
||||||
})
|
})
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
@ -224,7 +226,7 @@ class ServerProcess:
|
|||||||
result.headers = dict(response.headers)
|
result.headers = dict(response.headers)
|
||||||
result.status_code = response.status_code
|
result.status_code = response.status_code
|
||||||
result.body = response.json() if parse_body else None
|
result.body = response.json() if parse_body else None
|
||||||
print("Response from server", result.body)
|
print("Response from server", json.dumps(result.body, indent=2))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def make_stream_request(
|
def make_stream_request(
|
||||||
@ -245,7 +247,7 @@ class ServerProcess:
|
|||||||
break
|
break
|
||||||
elif line.startswith('data: '):
|
elif line.startswith('data: '):
|
||||||
data = json.loads(line[6:])
|
data = json.loads(line[6:])
|
||||||
print("Partial response from server", data)
|
print("Partial response from server", json.dumps(data, indent=2))
|
||||||
yield data
|
yield data
|
||||||
|
|
||||||
|
|
||||||
|
@ -164,6 +164,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
|
|||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
|
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
|
||||||
}
|
}
|
||||||
|
if (result.empty()) {
|
||||||
|
throw std::runtime_error("\"prompt\" must not be empty");
|
||||||
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse(
|
|||||||
const std::string & chat_template) {
|
const std::string & chat_template) {
|
||||||
json llama_params;
|
json llama_params;
|
||||||
|
|
||||||
llama_params["__oaicompat"] = true;
|
|
||||||
|
|
||||||
// Apply chat template to the list of messages
|
// Apply chat template to the list of messages
|
||||||
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
|
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
|
||||||
|
|
||||||
@ -648,3 +649,18 @@ static json format_detokenized_response(const std::string & content) {
|
|||||||
{"content", content}
|
{"content", content}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
|
||||||
|
json data = json::array();
|
||||||
|
for (const auto & lb : logit_bias) {
|
||||||
|
data.push_back(json{
|
||||||
|
{"bias", lb.bias},
|
||||||
|
{"token", lb.token},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string safe_json_to_str(json data) {
|
||||||
|
return data.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user