mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
server: allow filtering llama server response fields (#10940)
* llama_server_response_fields * llama_server_response_fields_fix_issues * params fixes * fix * clarify docs * change to "response_fields" --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
parent
30caac3a68
commit
09fe2e7613
@ -450,6 +450,8 @@ These words will not be included in the completion, so make sure to add them to
|
|||||||
|
|
||||||
`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
|
`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
|
||||||
|
|
||||||
|
`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
|
||||||
|
|
||||||
**Response format**
|
**Response format**
|
||||||
|
|
||||||
- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
|
- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
|
||||||
|
@ -92,6 +92,7 @@ struct slot_params {
|
|||||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||||
|
|
||||||
std::vector<std::string> antiprompt;
|
std::vector<std::string> antiprompt;
|
||||||
|
std::vector<std::string> response_fields;
|
||||||
bool timings_per_token = false;
|
bool timings_per_token = false;
|
||||||
bool post_sampling_probs = false;
|
bool post_sampling_probs = false;
|
||||||
bool ignore_eos = false;
|
bool ignore_eos = false;
|
||||||
@ -209,6 +210,7 @@ struct server_task {
|
|||||||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||||
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
||||||
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
||||||
|
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
||||||
|
|
||||||
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
||||||
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
||||||
@ -522,6 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||||||
|
|
||||||
bool post_sampling_probs;
|
bool post_sampling_probs;
|
||||||
std::vector<completion_token_output> probs_output;
|
std::vector<completion_token_output> probs_output;
|
||||||
|
std::vector<std::string> response_fields;
|
||||||
|
|
||||||
slot_params generation_params;
|
slot_params generation_params;
|
||||||
|
|
||||||
@ -568,7 +571,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||||||
if (!stream && !probs_output.empty()) {
|
if (!stream && !probs_output.empty()) {
|
||||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||||
}
|
}
|
||||||
return res;
|
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
json to_json_oaicompat_chat() {
|
json to_json_oaicompat_chat() {
|
||||||
@ -2066,6 +2069,7 @@ struct server_context {
|
|||||||
res->tokens = slot.generated_tokens;
|
res->tokens = slot.generated_tokens;
|
||||||
res->timings = slot.get_timings();
|
res->timings = slot.get_timings();
|
||||||
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
||||||
|
res->response_fields = slot.params.response_fields;
|
||||||
|
|
||||||
res->truncated = slot.truncated;
|
res->truncated = slot.truncated;
|
||||||
res->n_decoded = slot.n_decoded;
|
res->n_decoded = slot.n_decoded;
|
||||||
|
@ -257,6 +257,40 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
|||||||
# assert match_regex(re_content, res.body["content"])
|
# assert match_regex(re_content, res.body["content"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"prompt,n_predict,response_fields",
|
||||||
|
[
|
||||||
|
("I believe the meaning of life is", 8, []),
|
||||||
|
("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_completion_response_fields(
|
||||||
|
prompt: str, n_predict: int, response_fields: list[str]
|
||||||
|
):
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request(
|
||||||
|
"POST",
|
||||||
|
"/completion",
|
||||||
|
data={
|
||||||
|
"n_predict": n_predict,
|
||||||
|
"prompt": prompt,
|
||||||
|
"response_fields": response_fields,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert "content" in res.body
|
||||||
|
assert len(res.body["content"])
|
||||||
|
if len(response_fields):
|
||||||
|
assert res.body["generation_settings/n_predict"] == n_predict
|
||||||
|
assert res.body["prompt"] == "<s> " + prompt
|
||||||
|
assert isinstance(res.body["content"], str)
|
||||||
|
assert len(res.body) == len(response_fields)
|
||||||
|
else:
|
||||||
|
assert len(res.body)
|
||||||
|
assert "generation_settings" in res.body
|
||||||
|
|
||||||
|
|
||||||
def test_n_probs():
|
def test_n_probs():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
|
@ -90,6 +90,28 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get value by path(key1 / key2)
|
||||||
|
static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
|
||||||
|
json result = json::object();
|
||||||
|
|
||||||
|
for (const std::string & path : paths) {
|
||||||
|
json current = js;
|
||||||
|
const auto keys = string_split<std::string>(path, /*separator*/ '/');
|
||||||
|
bool valid_path = true;
|
||||||
|
for (const std::string & k : keys) {
|
||||||
|
if (valid_path && current.is_object() && current.contains(k)) {
|
||||||
|
current = current[k];
|
||||||
|
} else {
|
||||||
|
valid_path = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (valid_path) {
|
||||||
|
result[path] = current;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this handles 2 cases:
|
* this handles 2 cases:
|
||||||
* - only string, example: "string"
|
* - only string, example: "string"
|
||||||
|
Loading…
Reference in New Issue
Block a user