llama_server_response_fields

This commit is contained in:
nvrxq 2024-12-18 01:21:44 +03:00
parent 081b29bd2a
commit 2e04ccf4e6
2 changed files with 32 additions and 1 deletions

View File

@ -91,6 +91,7 @@ struct slot_params {
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector<std::string> antiprompt;
std::vector<std::string> requested_fields;
bool timings_per_token = false;
bool ignore_eos = false;
@ -205,6 +206,7 @@ struct server_task {
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.requested_fields = json_value(data, "requested_fields", std::vector<std::string>());
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@ -482,6 +484,7 @@ struct server_task_result_cmpl_final : server_task_result {
stop_type stop = STOP_TYPE_NONE;
std::vector<completion_token_output> probs_output;
std::vector<std::string> requested_fields;
slot_params generation_params;
@ -527,7 +530,7 @@ struct server_task_result_cmpl_final : server_task_result {
if (!probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
}
return res;
return requested_fields.empty() ? res : json_get_nested_values(requested_fields, res);
}
json to_json_oaicompat_chat() {
@ -1960,6 +1963,7 @@ struct server_context {
res->content = slot.generated_text;
res->timings = slot.get_timings();
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
res->requested_fields = slot.params.requested_fields;
res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;

View File

@ -88,6 +88,33 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
return false;
}
// get value by path(key1 / key2)
static json json_get_nested_values(const std::vector<std::string>& paths, const json& js) {
json result = json::object();
for (const std::string& path : paths) {
json current = js;
std::istringstream stream(path);
std::string key;
std::vector<std::string> keys;
while (std::getline(stream, key, '/')) {
keys.push_back(key);
}
bool valid_path = true;
for (const std::string& k : keys) {
if (valid_path && current.is_object() && current.contains(k)) {
current = current[k];
} else {
valid_path = false;
}
}
if (valid_path) {
result[path] = current;
}
}
return result;
}
/**
* this handles 2 cases:
* - only string, example: "string"