mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-04 15:43:53 +01:00
llama_server_response_fields
This commit is contained in:
parent
081b29bd2a
commit
2e04ccf4e6
@ -91,6 +91,7 @@ struct slot_params {
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
std::vector<std::string> requested_fields;
|
||||
bool timings_per_token = false;
|
||||
bool ignore_eos = false;
|
||||
|
||||
@ -205,6 +206,7 @@ struct server_task {
|
||||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
||||
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
||||
params.requested_fields = json_value(data, "requested_fields", std::vector<std::string>());
|
||||
|
||||
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
||||
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
||||
@ -482,6 +484,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
stop_type stop = STOP_TYPE_NONE;
|
||||
|
||||
std::vector<completion_token_output> probs_output;
|
||||
std::vector<std::string> requested_fields;
|
||||
|
||||
slot_params generation_params;
|
||||
|
||||
@ -527,7 +530,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
if (!probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
|
||||
}
|
||||
return res;
|
||||
return requested_fields.empty() ? res : json_get_nested_values(requested_fields, res);
|
||||
}
|
||||
|
||||
json to_json_oaicompat_chat() {
|
||||
@ -1960,6 +1963,7 @@ struct server_context {
|
||||
res->content = slot.generated_text;
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
||||
res->requested_fields = slot.params.requested_fields;
|
||||
|
||||
res->truncated = slot.truncated;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
|
@ -88,6 +88,33 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// get value by path(key1 / key2)
|
||||
static json json_get_nested_values(const std::vector<std::string>& paths, const json& js) {
|
||||
json result = json::object();
|
||||
|
||||
for (const std::string& path : paths) {
|
||||
json current = js;
|
||||
std::istringstream stream(path);
|
||||
std::string key;
|
||||
std::vector<std::string> keys;
|
||||
while (std::getline(stream, key, '/')) {
|
||||
keys.push_back(key);
|
||||
}
|
||||
bool valid_path = true;
|
||||
for (const std::string& k : keys) {
|
||||
if (valid_path && current.is_object() && current.contains(k)) {
|
||||
current = current[k];
|
||||
} else {
|
||||
valid_path = false;
|
||||
}
|
||||
}
|
||||
if (valid_path) {
|
||||
result[path] = current;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* this handles 2 cases:
|
||||
* - only string, example: "string"
|
||||
|
Loading…
Reference in New Issue
Block a user