mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 12:33:06 +01:00
server : bring back info of final chunk in stream mode (#10722)
* server : bring back into to final chunk in stream mode * clarify a bit * traling space
This commit is contained in:
parent
06d70147e6
commit
e52522b869
@ -392,7 +392,7 @@ struct server_task_result {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
virtual bool is_stop() {
|
virtual bool is_stop() {
|
||||||
// only used by server_task_result_cmpl_partial
|
// only used by server_task_result_cmpl_*
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
virtual int get_index() {
|
virtual int get_index() {
|
||||||
@ -478,14 +478,20 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual bool is_stop() override {
|
||||||
|
return true; // in stream mode, final responses are considered stop
|
||||||
|
}
|
||||||
|
|
||||||
virtual json to_json() override {
|
virtual json to_json() override {
|
||||||
return oaicompat ? to_json_oaicompat_chat() : to_json_non_oaicompat();
|
return oaicompat
|
||||||
|
? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
|
||||||
|
: to_json_non_oaicompat();
|
||||||
}
|
}
|
||||||
|
|
||||||
json to_json_non_oaicompat() {
|
json to_json_non_oaicompat() {
|
||||||
json res = json {
|
json res = json {
|
||||||
{"index", index},
|
{"index", index},
|
||||||
{"content", content},
|
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||||
{"id_slot", id_slot},
|
{"id_slot", id_slot},
|
||||||
{"stop", true},
|
{"stop", true},
|
||||||
{"model", oaicompat_model},
|
{"model", oaicompat_model},
|
||||||
@ -546,18 +552,46 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
json to_json_oaicompat_chat_stream() {
|
||||||
|
std::time_t t = std::time(0);
|
||||||
|
std::string finish_reason = "length";
|
||||||
|
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||||
|
finish_reason = "stop";
|
||||||
|
}
|
||||||
|
|
||||||
|
json choices = json::array({json{{"finish_reason", finish_reason},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", json::object()}}});
|
||||||
|
|
||||||
|
json ret = json {
|
||||||
|
{"choices", choices},
|
||||||
|
{"created", t},
|
||||||
|
{"id", oaicompat_cmpl_id},
|
||||||
|
{"model", oaicompat_model},
|
||||||
|
{"object", "chat.completion.chunk"},
|
||||||
|
{"usage", json {
|
||||||
|
{"completion_tokens", n_decoded},
|
||||||
|
{"prompt_tokens", n_prompt_tokens},
|
||||||
|
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||||
|
}},
|
||||||
|
};
|
||||||
|
|
||||||
|
if (timings.prompt_n >= 0) {
|
||||||
|
ret.push_back({"timings", timings.to_json()});
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_cmpl_partial : server_task_result {
|
struct server_task_result_cmpl_partial : server_task_result {
|
||||||
int index = 0;
|
int index = 0;
|
||||||
std::string content;
|
std::string content;
|
||||||
|
|
||||||
bool truncated;
|
|
||||||
int32_t n_decoded;
|
int32_t n_decoded;
|
||||||
int32_t n_prompt_tokens;
|
int32_t n_prompt_tokens;
|
||||||
|
|
||||||
stop_type stop = STOP_TYPE_NONE;
|
|
||||||
|
|
||||||
std::vector<completion_token_output> probs_output;
|
std::vector<completion_token_output> probs_output;
|
||||||
result_timings timings;
|
result_timings timings;
|
||||||
|
|
||||||
@ -573,20 +607,19 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||||||
}
|
}
|
||||||
|
|
||||||
virtual bool is_stop() override {
|
virtual bool is_stop() override {
|
||||||
return stop != STOP_TYPE_NONE;
|
return false; // in stream mode, partial responses are not considered stop
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual json to_json() override {
|
virtual json to_json() override {
|
||||||
if (oaicompat) {
|
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
||||||
return to_json_oaicompat();
|
}
|
||||||
}
|
|
||||||
bool is_stop = stop != STOP_TYPE_NONE;
|
json to_json_non_oaicompat() {
|
||||||
// non-OAI-compat JSON
|
// non-OAI-compat JSON
|
||||||
json res = json {
|
json res = json {
|
||||||
{"index", index},
|
{"index", index},
|
||||||
{"content", content},
|
{"content", content},
|
||||||
{"stop_type", stop_type_to_str(stop)},
|
{"stop", false},
|
||||||
{"stop", is_stop},
|
|
||||||
{"id_slot", id_slot},
|
{"id_slot", id_slot},
|
||||||
{"tokens_predicted", n_decoded},
|
{"tokens_predicted", n_decoded},
|
||||||
{"tokens_evaluated", n_prompt_tokens},
|
{"tokens_evaluated", n_prompt_tokens},
|
||||||
@ -598,72 +631,54 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||||||
if (!probs_output.empty()) {
|
if (!probs_output.empty()) {
|
||||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
|
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
|
||||||
}
|
}
|
||||||
if (is_stop) {
|
|
||||||
res.push_back({"truncated", truncated});
|
|
||||||
}
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
json to_json_oaicompat() {
|
json to_json_oaicompat() {
|
||||||
bool first = n_decoded == 0;
|
bool first = n_decoded == 0;
|
||||||
|
|
||||||
std::string finish_reason;
|
|
||||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
|
||||||
finish_reason = "stop";
|
|
||||||
} else if (stop == STOP_TYPE_LIMIT) {
|
|
||||||
finish_reason = "length";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::time_t t = std::time(0);
|
std::time_t t = std::time(0);
|
||||||
|
|
||||||
json choices;
|
json choices;
|
||||||
|
|
||||||
if (!finish_reason.empty()) {
|
if (first) {
|
||||||
choices = json::array({json{{"finish_reason", finish_reason},
|
if (content.empty()) {
|
||||||
{"index", 0},
|
choices = json::array({json{{"finish_reason", nullptr},
|
||||||
{"delta", json::object()}}});
|
|
||||||
} else {
|
|
||||||
if (first) {
|
|
||||||
if (content.empty()) {
|
|
||||||
choices = json::array({json{{"finish_reason", nullptr},
|
|
||||||
{"index", 0},
|
|
||||||
{"delta", json{{"role", "assistant"}}}}});
|
|
||||||
} else {
|
|
||||||
// We have to send this as two updates to conform to openai behavior
|
|
||||||
json initial_ret = json{{"choices", json::array({json{
|
|
||||||
{"finish_reason", nullptr},
|
|
||||||
{"index", 0},
|
{"index", 0},
|
||||||
{"delta", json{
|
{"delta", json{{"role", "assistant"}}}}});
|
||||||
{"role", "assistant"}
|
|
||||||
}}}})},
|
|
||||||
{"created", t},
|
|
||||||
{"id", oaicompat_cmpl_id},
|
|
||||||
{"model", oaicompat_model},
|
|
||||||
{"object", "chat.completion.chunk"}};
|
|
||||||
|
|
||||||
json second_ret = json{
|
|
||||||
{"choices", json::array({json{{"finish_reason", nullptr},
|
|
||||||
{"index", 0},
|
|
||||||
{"delta", json{
|
|
||||||
{"content", content}}}
|
|
||||||
}})},
|
|
||||||
{"created", t},
|
|
||||||
{"id", oaicompat_cmpl_id},
|
|
||||||
{"model", oaicompat_model},
|
|
||||||
{"object", "chat.completion.chunk"}};
|
|
||||||
|
|
||||||
return std::vector<json>({initial_ret, second_ret});
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
choices = json::array({json{
|
// We have to send this as two updates to conform to openai behavior
|
||||||
{"finish_reason", nullptr},
|
json initial_ret = json{{"choices", json::array({json{
|
||||||
{"index", 0},
|
{"finish_reason", nullptr},
|
||||||
{"delta",
|
{"index", 0},
|
||||||
json{
|
{"delta", json{
|
||||||
{"content", content},
|
{"role", "assistant"}
|
||||||
}},
|
}}}})},
|
||||||
}});
|
{"created", t},
|
||||||
|
{"id", oaicompat_cmpl_id},
|
||||||
|
{"model", oaicompat_model},
|
||||||
|
{"object", "chat.completion.chunk"}};
|
||||||
|
|
||||||
|
json second_ret = json{
|
||||||
|
{"choices", json::array({json{{"finish_reason", nullptr},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", json{
|
||||||
|
{"content", content}}}
|
||||||
|
}})},
|
||||||
|
{"created", t},
|
||||||
|
{"id", oaicompat_cmpl_id},
|
||||||
|
{"model", oaicompat_model},
|
||||||
|
{"object", "chat.completion.chunk"}};
|
||||||
|
|
||||||
|
return std::vector<json>({initial_ret, second_ret});
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
choices = json::array({json{
|
||||||
|
{"finish_reason", nullptr},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta",
|
||||||
|
json{
|
||||||
|
{"content", content},
|
||||||
|
}},
|
||||||
|
}});
|
||||||
}
|
}
|
||||||
|
|
||||||
json ret = json {
|
json ret = json {
|
||||||
@ -678,14 +693,6 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||||||
ret.push_back({"timings", timings.to_json()});
|
ret.push_back({"timings", timings.to_json()});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!finish_reason.empty()) {
|
|
||||||
ret.push_back({"usage", json {
|
|
||||||
{"completion_tokens", n_decoded},
|
|
||||||
{"prompt_tokens", n_prompt_tokens},
|
|
||||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
|
||||||
}});
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::vector<json>({ret});
|
return std::vector<json>({ret});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1888,12 +1895,9 @@ struct server_context {
|
|||||||
res->index = slot.index;
|
res->index = slot.index;
|
||||||
res->content = tkn.text_to_send;
|
res->content = tkn.text_to_send;
|
||||||
|
|
||||||
res->truncated = slot.truncated;
|
|
||||||
res->n_decoded = slot.n_decoded;
|
res->n_decoded = slot.n_decoded;
|
||||||
res->n_prompt_tokens = slot.n_prompt_tokens;
|
res->n_prompt_tokens = slot.n_prompt_tokens;
|
||||||
|
|
||||||
res->stop = slot.stop;
|
|
||||||
|
|
||||||
res->verbose = slot.params.verbose;
|
res->verbose = slot.params.verbose;
|
||||||
res->oaicompat = slot.params.oaicompat;
|
res->oaicompat = slot.params.oaicompat;
|
||||||
res->oaicompat_chat = slot.params.oaicompat_chat;
|
res->oaicompat_chat = slot.params.oaicompat_chat;
|
||||||
@ -1924,12 +1928,6 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void send_final_response(server_slot & slot) {
|
void send_final_response(server_slot & slot) {
|
||||||
if (slot.params.stream) {
|
|
||||||
// if in stream mode, send the last partial response
|
|
||||||
send_partial_response(slot, {0, "", {}});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto res = std::make_unique<server_task_result_cmpl_final>();
|
auto res = std::make_unique<server_task_result_cmpl_final>();
|
||||||
res->id = slot.id_task;
|
res->id = slot.id_task;
|
||||||
res->id_slot = slot.id;
|
res->id_slot = slot.id;
|
||||||
@ -1948,6 +1946,7 @@ struct server_context {
|
|||||||
res->stop = slot.stop;
|
res->stop = slot.stop;
|
||||||
|
|
||||||
res->verbose = slot.params.verbose;
|
res->verbose = slot.params.verbose;
|
||||||
|
res->stream = slot.params.stream;
|
||||||
res->oaicompat = slot.params.oaicompat;
|
res->oaicompat = slot.params.oaicompat;
|
||||||
res->oaicompat_chat = slot.params.oaicompat_chat;
|
res->oaicompat_chat = slot.params.oaicompat_chat;
|
||||||
res->oaicompat_model = slot.params.oaicompat_model;
|
res->oaicompat_model = slot.params.oaicompat_model;
|
||||||
@ -2100,7 +2099,10 @@ struct server_context {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr);
|
GGML_ASSERT(
|
||||||
|
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|
||||||
|
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
|
||||||
|
);
|
||||||
if (!result_handler(result)) {
|
if (!result_handler(result)) {
|
||||||
cancel_tasks(id_tasks);
|
cancel_tasks(id_tasks);
|
||||||
break;
|
break;
|
||||||
|
@ -42,10 +42,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
|
|||||||
})
|
})
|
||||||
content = ""
|
content = ""
|
||||||
for data in res:
|
for data in res:
|
||||||
|
assert "stop" in data and type(data["stop"]) == bool
|
||||||
if data["stop"]:
|
if data["stop"]:
|
||||||
assert data["timings"]["prompt_n"] == n_prompt
|
assert data["timings"]["prompt_n"] == n_prompt
|
||||||
assert data["timings"]["predicted_n"] == n_predicted
|
assert data["timings"]["predicted_n"] == n_predicted
|
||||||
assert data["truncated"] == truncated
|
assert data["truncated"] == truncated
|
||||||
|
assert data["stop_type"] == "limit"
|
||||||
|
assert "generation_settings" in data
|
||||||
|
assert server.n_predict is not None
|
||||||
|
assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
|
||||||
|
assert data["generation_settings"]["seed"] == server.seed
|
||||||
assert match_regex(re_content, content)
|
assert match_regex(re_content, content)
|
||||||
else:
|
else:
|
||||||
content += data["content"]
|
content += data["content"]
|
||||||
|
Loading…
Reference in New Issue
Block a user