diff --git a/common/common.h b/common/common.h
index 9b1508a15..0373fd3ea 100644
--- a/common/common.h
+++ b/common/common.h
@@ -133,6 +133,7 @@ struct common_params_sampling {
     bool penalize_nl = false; // consider newlines as a repeatable token
     bool ignore_eos  = false;
     bool no_perf     = false; // disable performance metrics
+    bool timing_per_token = false;
 
     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
diff --git a/examples/server/README.md b/examples/server/README.md
index 877768c8b..45ffb547f 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -416,6 +416,8 @@ node index.js
 
 `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.
 
+`timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
+
 **Response format**
 
 - Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion.
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 1c765f0ea..8eca14b86 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -177,6 +177,8 @@ struct server_slot {
     bool stopped_word  = false;
     bool stopped_limit = false;
 
+    bool timings_per_token = false;
+
     bool oaicompat = false;
 
     std::string oaicompat_model;
@@ -882,6 +884,8 @@ struct server_context {
             slot.oaicompat_model = "";
         }
 
+        slot.timings_per_token = json_value(data, "timings_per_token", false);
+
         slot.params.stream       = json_value(data, "stream", false);
         slot.params.cache_prompt = json_value(data, "cache_prompt", true);
         slot.params.n_predict    = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
@@ -1279,6 +1283,7 @@ struct server_context {
             {"speculative.n_max", slot.params.speculative.n_max},
             {"speculative.n_min", slot.params.speculative.n_min},
             {"speculative.p_min", slot.params.speculative.p_min},
+            {"timings_per_token", slot.timings_per_token},
         };
     }
@@ -1336,6 +1341,10 @@ struct server_context {
             res.data["model"] = slot.oaicompat_model;
         }
 
+        if (slot.timings_per_token) {
+            res.data["timings"] = slot.get_formated_timings();
+        }
+
         queue_results.send(res);
     }
@@ -2274,12 +2283,17 @@ struct server_context {
             common_sampler_accept(slot.smpl, id, true);
 
             slot.n_decoded += 1;
+
+            const int64_t t_current = ggml_time_us();
+
             if (slot.n_decoded == 1) {
-                slot.t_start_generation = ggml_time_us();
+                slot.t_start_generation = t_current;
                 slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
                 metrics.on_prompt_eval(slot);
             }
 
+            slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
+
             completion_token_output result;
             result.tok = id;
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 1048d6fca..8a439f9ef 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -146,3 +146,20 @@ def test_invalid_chat_completion_req(messages):
     })
     assert res.status_code == 400 or res.status_code == 500
     assert "error" in res.body
+
+
+def test_chat_completion_with_timings_per_token():
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/chat/completions", data={
+        "max_tokens": 10,
+        "messages": [{"role": "user", "content": "test"}],
+        "stream": True,
+        "timings_per_token": True,
+    })
+    for data in res:
+        assert "timings" in data
+        assert "prompt_per_second" in data["timings"]
+        assert "predicted_per_second" in data["timings"]
+        assert "predicted_n" in data["timings"]
+        assert data["timings"]["predicted_n"] <= 10
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 1665e9dc3..e4451532c 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -650,6 +650,10 @@ static json format_final_response_oaicompat(const json & request, const json & r
         res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
     }
 
+    if (result.contains("timings")) {
+        res.push_back({"timings", json_value(result, "timings", json::object())});
+    }
+
     return res;
 }
 
@@ -740,6 +744,11 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
         {"model", modelname},
         {"object", "chat.completion.chunk"}
     };
+
+    if (result.contains("timings")) {
+        ret.push_back({"timings", json_value(result, "timings", json::object())});
+    }
+
     if (!finish_reason.empty()) {
         int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
         int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
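For reviewers, a minimal client-side sketch of how the new option is consumed (not part of the patch): it assumes a llama-server instance listening on `http://localhost:8080` and uses the third-party `requests` package; the endpoint, request fields, and `timings` keys mirror the streaming test above.

```python
# Sketch: stream a chat completion with per-token timings enabled.
# Assumes llama-server is running on http://localhost:8080.
import json

import requests

resp = requests.post(
    "http://localhost:8080/chat/completions",
    json={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "timings_per_token": True,
    },
    stream=True,
)

for line in resp.iter_lines():
    # Server-sent events: payload lines look like b"data: {...}".
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":  # OpenAI-style end-of-stream sentinel, if sent
        break
    chunk = json.loads(payload)
    timings = chunk.get("timings")
    if timings:
        # Keys asserted by the test: prompt_per_second, predicted_per_second, predicted_n
        print(timings.get("prompt_per_second"), timings.get("predicted_per_second"))
```

Because `send_partial_response` only attaches `timings` when `slot.timings_per_token` is set, streamed chunks for clients that do not opt in are unchanged by default.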