server : fix logprobs, make it OAI-compatible (#10783)

* server : fix logprobs, make it openai-compatible

* update docs

* add std::log

* return pre-sampling p

* sort before apply softmax

* add comment

* fix test

* set p for sampled token

* update docs

* add --multi-token-probs

* update docs

* add `post_sampling_probs` option

* update docs [no ci]

* remove --multi-token-probs

* "top_probs" with "post_sampling_probs"

* resolve review comments

* rename struct token_prob to prob_info

* correct comment placement

* fix setting prob for sampled token
This commit is contained in:
Xuan Son Nguyen 2024-12-19 15:40:08 +01:00 committed by GitHub
parent a3c33b1dce
commit 57bb2c40cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 396 additions and 107 deletions

View File

@ -343,6 +343,10 @@ node index.js
### POST `/completion`: Given a `prompt`, it returns the predicted completion. ### POST `/completion`: Given a `prompt`, it returns the predicted completion.
> [!IMPORTANT]
>
> This endpoint is **not** OAI-compatible
*Options:* *Options:*
`prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, if `cache_prompt` is `true`, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. A `BOS` token is inserted at the start, if all of the following conditions are true: `prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, if `cache_prompt` is `true`, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. A `BOS` token is inserted at the start, if all of the following conditions are true:
@ -444,38 +448,68 @@ These words will not be included in the completion, so make sure to add them to
`timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false` `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
**Response format** **Response format**
- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support. - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure: - `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements:
```json ```json
{ {
"content": "<the token generated by the model>", "content": "<the generated completion text>",
"tokens": [ generated token ids if requested ], "tokens": [ generated token ids if requested ],
...
"probs": [ "probs": [
{ {
"prob": float, "id": <token id>,
"tok_str": "<most likely token>" "logprob": float,
"token": "<most likely token>",
"bytes": [int, int, ...],
"top_logprobs": [
{
"id": <token id>,
"logprob": float,
"token": "<token text>",
"bytes": [int, int, ...],
}, },
{ {
"prob": float, "id": <token id>,
"tok_str": "<second most likely token>" "logprob": float,
"token": "<token text>",
"bytes": [int, int, ...],
},
...
]
},
{
"id": <token id>,
"logprob": float,
"token": "<most likely token>",
"bytes": [int, int, ...],
"top_logprobs": [
...
]
}, },
... ...
] ]
}, },
``` ```
Please note that if `post_sampling_probs` is set to `true`:
Notice that each `probs` is an array of length `n_probs`. - `logprob` will be replaced with `prob`, with the value between 0.0 and 1.0
- `top_logprobs` will be replaced with `top_probs`. Each element contains:
- `id`: token ID
- `token`: token in string
- `bytes`: token in bytes
- `prob`: token probability, with the value between 0.0 and 1.0
- Number of elements in `top_probs` may be less than `n_probs`
- `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string. - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
- `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request. - `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.
- `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options) - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
- `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.). - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
- `model`: The path to the model loaded with `-m` - `model`: The model alias (for model path, please use `/props` endpoint)
- `prompt`: The provided `prompt` - `prompt`: The processed `prompt` (special tokens may be added)
- `stop_type`: Indicating whether the completion has stopped. Possible values are: - `stop_type`: Indicating whether the completion has stopped. Possible values are:
- `none`: Generating (not stopped) - `none`: Generating (not stopped)
- `eos`: Stopped because it encountered the EOS token - `eos`: Stopped because it encountered the EOS token

View File

@ -93,6 +93,7 @@ struct slot_params {
std::vector<std::string> antiprompt; std::vector<std::string> antiprompt;
bool timings_per_token = false; bool timings_per_token = false;
bool post_sampling_probs = false;
bool ignore_eos = false; bool ignore_eos = false;
struct common_params_sampling sampling; struct common_params_sampling sampling;
@ -151,6 +152,7 @@ struct slot_params {
{"speculative.n_min", speculative.n_min}, {"speculative.n_min", speculative.n_min},
{"speculative.p_min", speculative.p_min}, {"speculative.p_min", speculative.p_min},
{"timings_per_token", timings_per_token}, {"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
}; };
} }
}; };
@ -231,6 +233,7 @@ struct server_task {
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
@ -436,36 +439,67 @@ inline std::string stop_type_to_str(stop_type type) {
struct completion_token_output { struct completion_token_output {
llama_token tok; llama_token tok;
float prob;
std::string text_to_send; std::string text_to_send;
struct token_prob { struct prob_info {
llama_token tok; llama_token tok;
std::string tok_str; std::string txt;
float prob; float prob;
}; };
std::vector<token_prob> probs; std::vector<prob_info> probs;
json to_json() const { json to_json(bool post_sampling_probs) const {
json probs_for_token = json::array(); json probs_for_token = json::array();
for (const auto & p : probs) { for (const auto & p : probs) {
std::string txt(p.txt);
txt.resize(validate_utf8(txt));
probs_for_token.push_back(json { probs_for_token.push_back(json {
{"tok_str", p.tok_str}, {"id", p.tok},
{"prob", p.prob}, {"token", txt},
{"bytes", str_to_bytes(p.txt)},
{
post_sampling_probs ? "prob" : "logprob",
post_sampling_probs ? p.prob : logarithm(p.prob)
},
}); });
} }
return probs_for_token; return probs_for_token;
} }
static json probs_vector_to_json(const std::vector<completion_token_output> & probs) { static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
json out = json::array(); json out = json::array();
for (const auto & prob : probs) { for (const auto & p : probs) {
const std::string tok_str = prob.text_to_send; std::string txt(p.text_to_send);
txt.resize(validate_utf8(txt));
out.push_back(json { out.push_back(json {
{"content", tok_str}, {"id", p.tok},
{"probs", prob.to_json()}, {"token", txt},
{"bytes", str_to_bytes(p.text_to_send)},
{
post_sampling_probs ? "prob" : "logprob",
post_sampling_probs ? p.prob : logarithm(p.prob)
},
{
post_sampling_probs ? "top_probs" : "top_logprobs",
p.to_json(post_sampling_probs)
},
}); });
} }
return out; return out;
} }
static float logarithm(float x) {
// nlohmann::json converts -inf to null, so we need to prevent that
return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
}
static std::vector<unsigned char> str_to_bytes(const std::string & str) {
std::vector<unsigned char> bytes;
for (unsigned char c : str) {
bytes.push_back(c);
}
return bytes;
}
}; };
struct server_task_result_cmpl_final : server_task_result { struct server_task_result_cmpl_final : server_task_result {
@ -486,6 +520,7 @@ struct server_task_result_cmpl_final : server_task_result {
std::string stopping_word; std::string stopping_word;
stop_type stop = STOP_TYPE_NONE; stop_type stop = STOP_TYPE_NONE;
bool post_sampling_probs;
std::vector<completion_token_output> probs_output; std::vector<completion_token_output> probs_output;
slot_params generation_params; slot_params generation_params;
@ -530,8 +565,8 @@ struct server_task_result_cmpl_final : server_task_result {
{"tokens_cached", n_tokens_cached}, {"tokens_cached", n_tokens_cached},
{"timings", timings.to_json()}, {"timings", timings.to_json()},
}; };
if (!probs_output.empty()) { if (!stream && !probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output); res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
} }
return res; return res;
} }
@ -542,19 +577,25 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop"; finish_reason = "stop";
} }
json choices = json::array({json{ json choice = json{
{"finish_reason", finish_reason}, {"finish_reason", finish_reason},
{"index", 0}, {"index", 0},
{"message", json { {"message", json {
{"content", content}, {"content", content},
{"role", "assistant"} {"role", "assistant"}
} }
}}}); }};
if (!stream && probs_output.size() > 0) {
choice["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
};
}
std::time_t t = std::time(0); std::time_t t = std::time(0);
json res = json { json res = json {
{"choices", choices}, {"choices", json::array({choice})},
{"created", t}, {"created", t},
{"model", oaicompat_model}, {"model", oaicompat_model},
{"object", "chat.completion"}, {"object", "chat.completion"},
@ -584,12 +625,14 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop"; finish_reason = "stop";
} }
json choices = json::array({json{{"finish_reason", finish_reason}, json choice = json{
{"finish_reason", finish_reason},
{"index", 0}, {"index", 0},
{"delta", json::object()}}}); {"delta", json::object()}
};
json ret = json { json ret = json {
{"choices", choices}, {"choices", json::array({choice})},
{"created", t}, {"created", t},
{"id", oaicompat_cmpl_id}, {"id", oaicompat_cmpl_id},
{"model", oaicompat_model}, {"model", oaicompat_model},
@ -618,7 +661,8 @@ struct server_task_result_cmpl_partial : server_task_result {
int32_t n_decoded; int32_t n_decoded;
int32_t n_prompt_tokens; int32_t n_prompt_tokens;
std::vector<completion_token_output> probs_output; bool post_sampling_probs;
completion_token_output prob_output;
result_timings timings; result_timings timings;
// OAI-compat fields // OAI-compat fields
@ -655,8 +699,8 @@ struct server_task_result_cmpl_partial : server_task_result {
if (timings.prompt_n > 0) { if (timings.prompt_n > 0) {
res.push_back({"timings", timings.to_json()}); res.push_back({"timings", timings.to_json()});
} }
if (!probs_output.empty()) { if (!prob_output.probs.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output); res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
} }
return res; return res;
} }
@ -708,6 +752,14 @@ struct server_task_result_cmpl_partial : server_task_result {
}}); }});
} }
GGML_ASSERT(choices.size() >= 1);
if (prob_output.probs.size() > 0) {
choices[0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
json ret = json { json ret = json {
{"choices", choices}, {"choices", choices},
{"created", t}, {"created", t},
@ -1001,7 +1053,6 @@ struct server_slot {
// stats // stats
size_t n_sent_text = 0; // number of sent text character size_t n_sent_text = 0; // number of sent text character
size_t n_sent_token_probs = 0;
int64_t t_start_process_prompt; int64_t t_start_process_prompt;
int64_t t_start_generation; int64_t t_start_generation;
@ -1023,7 +1074,6 @@ struct server_slot {
stopping_word = ""; stopping_word = "";
n_past = 0; n_past = 0;
n_sent_text = 0; n_sent_text = 0;
n_sent_token_probs = 0;
task_type = SERVER_TASK_TYPE_COMPLETION; task_type = SERVER_TASK_TYPE_COMPLETION;
generated_tokens.clear(); generated_tokens.clear();
@ -1764,7 +1814,7 @@ struct server_context {
bool process_token(completion_token_output & result, server_slot & slot) { bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling // remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special); const std::string token_str = result.text_to_send;
slot.sampled = result.tok; slot.sampled = result.tok;
slot.generated_text += token_str; slot.generated_text += token_str;
@ -1774,26 +1824,7 @@ struct server_context {
slot.has_next_token = true; slot.has_next_token = true;
// check if there is incomplete UTF-8 character at the end // check if there is incomplete UTF-8 character at the end
bool incomplete = false; bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
unsigned char c = slot.generated_text[slot.generated_text.size() - i];
if ((c & 0xC0) == 0x80) {
// continuation byte: 10xxxxxx
continue;
}
if ((c & 0xE0) == 0xC0) {
// 2-byte character: 110xxxxx ...
incomplete = i < 2;
} else if ((c & 0xF0) == 0xE0) {
// 3-byte character: 1110xxxx ...
incomplete = i < 3;
} else if ((c & 0xF8) == 0xF0) {
// 4-byte character: 11110xxx ...
incomplete = i < 4;
}
// else 1-byte character or invalid byte
break;
}
// search stop word and delete it // search stop word and delete it
if (!incomplete) { if (!incomplete) {
@ -1923,6 +1954,55 @@ struct server_context {
return slot.has_next_token; // continue return slot.has_next_token; // continue
} }
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
size_t n_probs = slot.params.sampling.n_probs;
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
if (post_sampling) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
const size_t max_probs = cur_p->size;
// set probability for sampled token
for (size_t i = 0; i < max_probs; i++) {
if (cur_p->data[i].id == result.tok) {
result.prob = cur_p->data[i].p;
break;
}
}
// set probability for top n_probs tokens
result.probs.reserve(max_probs);
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
result.probs.push_back({
cur_p->data[i].id,
common_detokenize(ctx, {cur_p->data[i].id}, special),
cur_p->data[i].p
});
}
} else {
// TODO: optimize this with min-p optimization
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
// set probability for sampled token
for (size_t i = 0; i < n_vocab; i++) {
// set probability for sampled token
if (cur[i].id == result.tok) {
result.prob = cur[i].p;
break;
}
}
// set probability for top n_probs tokens
result.probs.reserve(n_probs);
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
result.probs.push_back({
cur[i].id,
common_detokenize(ctx, {cur[i].id}, special),
cur[i].p
});
}
}
}
void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(task.id, error, type); send_error(task.id, error, type);
} }
@ -1952,6 +2032,7 @@ struct server_context {
res->n_decoded = slot.n_decoded; res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens; res->n_prompt_tokens = slot.n_prompt_tokens;
res->post_sampling_probs = slot.params.post_sampling_probs;
res->verbose = slot.params.verbose; res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat; res->oaicompat = slot.params.oaicompat;
@ -1961,17 +2042,7 @@ struct server_context {
// populate res.probs_output // populate res.probs_output
if (slot.params.sampling.n_probs > 0) { if (slot.params.sampling.n_probs > 0) {
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false); res->prob_output = tkn; // copy the token probs
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
std::vector<completion_token_output> probs_output;
if (probs_pos < probs_stop_pos) {
res->probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin() + probs_pos,
slot.generated_token_probs.begin() + probs_stop_pos);
}
} }
// populate timings if this is final response or timings_per_token is enabled // populate timings if this is final response or timings_per_token is enabled
@ -2000,6 +2071,7 @@ struct server_context {
res->has_new_line = slot.has_new_line; res->has_new_line = slot.has_new_line;
res->stopping_word = slot.stopping_word; res->stopping_word = slot.stopping_word;
res->stop = slot.stop; res->stop = slot.stop;
res->post_sampling_probs = slot.params.post_sampling_probs;
res->verbose = slot.params.verbose; res->verbose = slot.params.verbose;
res->stream = slot.params.stream; res->stream = slot.params.stream;
@ -2796,7 +2868,9 @@ struct server_context {
continue; // continue loop of slots continue; // continue loop of slots
} }
llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i); const int tok_idx = slot.i_batch - i;
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
slot.i_batch = -1; slot.i_batch = -1;
@ -2816,16 +2890,11 @@ struct server_context {
completion_token_output result; completion_token_output result;
result.tok = id; result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
const auto * cur_p = common_sampler_get_candidates(slot.smpl); if (slot.params.sampling.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
auto tok_id = cur_p->data[i].id;
result.probs.push_back({
tok_id,
tokens_to_output_formatted_string(ctx, tok_id),
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
});
} }
if (!process_token(result, slot)) { if (!process_token(result, slot)) {
@ -2910,6 +2979,10 @@ struct server_context {
completion_token_output result; completion_token_output result;
result.tok = ids[i]; result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
result.prob = 1.0f; // set later
// TODO: set result.probs
if (!process_token(result, slot)) { if (!process_token(result, slot)) {
// release slot because of stop condition // release slot because of stop condition

View File

@ -92,7 +92,6 @@ def test_chat_completion_with_openai_library():
seed=42, seed=42,
temperature=0.8, temperature=0.8,
) )
print(res)
assert res.choices[0].finish_reason == "length" assert res.choices[0].finish_reason == "length"
assert res.choices[0].message.content is not None assert res.choices[0].message.content is not None
assert match_regex("(Suddenly)+", res.choices[0].message.content) assert match_regex("(Suddenly)+", res.choices[0].message.content)
@ -163,3 +162,64 @@ def test_chat_completion_with_timings_per_token():
assert "predicted_per_second" in data["timings"] assert "predicted_per_second" in data["timings"]
assert "predicted_n" in data["timings"] assert "predicted_n" in data["timings"]
assert data["timings"]["predicted_n"] <= 10 assert data["timings"]["predicted_n"] <= 10
def test_logprobs():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
temperature=0.0,
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=5,
logprobs=True,
top_logprobs=10,
)
output_text = res.choices[0].message.content
aggregated_text = ''
assert res.choices[0].logprobs is not None
assert res.choices[0].logprobs.content is not None
for token in res.choices[0].logprobs.content:
aggregated_text += token.token
assert token.logprob <= 0.0
assert token.bytes is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text
def test_logprobs_stream():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
temperature=0.0,
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=5,
logprobs=True,
top_logprobs=10,
stream=True,
)
output_text = ''
aggregated_text = ''
for data in res:
choice = data.choices[0]
if choice.finish_reason is None:
if choice.delta.content:
output_text += choice.delta.content
assert choice.logprobs is not None
assert choice.logprobs.content is not None
for token in choice.logprobs.content:
aggregated_text += token.token
assert token.logprob <= 0.0
assert token.bytes is not None
assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text

View File

@ -270,9 +270,68 @@ def test_n_probs():
assert "completion_probabilities" in res.body assert "completion_probabilities" in res.body
assert len(res.body["completion_probabilities"]) == 5 assert len(res.body["completion_probabilities"]) == 5
for tok in res.body["completion_probabilities"]: for tok in res.body["completion_probabilities"]:
assert "probs" in tok assert "id" in tok and tok["id"] > 0
assert len(tok["probs"]) == 10 assert "token" in tok and type(tok["token"]) == str
for prob in tok["probs"]: assert "logprob" in tok and tok["logprob"] <= 0.0
assert "prob" in prob assert "bytes" in tok and type(tok["bytes"]) == list
assert "tok_str" in prob assert len(tok["top_logprobs"]) == 10
assert 0.0 <= prob["prob"] <= 1.0 for prob in tok["top_logprobs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "logprob" in prob and prob["logprob"] <= 0.0
assert "bytes" in prob and type(prob["bytes"]) == list
def test_n_probs_stream():
global server
server.start()
res = server.make_stream_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"n_probs": 10,
"temperature": 0.0,
"n_predict": 5,
"stream": True,
})
for data in res:
if data["stop"] == False:
assert "completion_probabilities" in data
assert len(data["completion_probabilities"]) == 1
for tok in data["completion_probabilities"]:
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "logprob" in tok and tok["logprob"] <= 0.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_logprobs"]) == 10
for prob in tok["top_logprobs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "logprob" in prob and prob["logprob"] <= 0.0
assert "bytes" in prob and type(prob["bytes"]) == list
def test_n_probs_post_sampling():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"n_probs": 10,
"temperature": 0.0,
"n_predict": 5,
"post_sampling_probs": True,
})
assert res.status_code == 200
assert "completion_probabilities" in res.body
assert len(res.body["completion_probabilities"]) == 5
for tok in res.body["completion_probabilities"]:
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_probs"]) == 10
for prob in tok["top_probs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
assert "bytes" in prob and type(prob["bytes"]) == list
# because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])

View File

@ -50,6 +50,8 @@ def test_embedding_multiple():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"input,is_multi_prompt", "input,is_multi_prompt",
[ [
# do not crash on empty input
("", False),
# single prompt # single prompt
("string", False), ("string", False),
([12, 34, 56], False), ([12, 34, 56], False),
@ -103,6 +105,7 @@ def test_embedding_pooling_none_oai():
# /v1/embeddings does not support pooling type 'none' # /v1/embeddings does not support pooling type 'none'
assert res.status_code == 400 assert res.status_code == 400
assert "error" in res.body
def test_embedding_openai_library_single(): def test_embedding_openai_library_single():

View File

@ -171,6 +171,36 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
return result; return result;
} }
// return the last index of character that can form a valid string
// if the last character is potentially cut in half, return the index before the cut
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
static size_t validate_utf8(const std::string& text) {
size_t len = text.size();
if (len == 0) return 0;
// Check the last few bytes to see if a multi-byte character is cut off
for (size_t i = 1; i <= 4 && i <= len; ++i) {
unsigned char c = text[len - i];
// Check for start of a multi-byte sequence from the end
if ((c & 0xE0) == 0xC0) {
// 2-byte character start: 110xxxxx
// Needs at least 2 bytes
if (i < 2) return len - i;
} else if ((c & 0xF0) == 0xE0) {
// 3-byte character start: 1110xxxx
// Needs at least 3 bytes
if (i < 3) return len - i;
} else if ((c & 0xF8) == 0xF0) {
// 4-byte character start: 11110xxx
// Needs at least 4 bytes
if (i < 4) return len - i;
}
}
// If no cut-off multi-byte character is found, return full length
return len;
}
// //
// template utils // template utils
// //
@ -671,3 +701,33 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
static std::string safe_json_to_str(json data) { static std::string safe_json_to_str(json data) {
return data.dump(-1, ' ', false, json::error_handler_t::replace); return data.dump(-1, ' ', false, json::error_handler_t::replace);
} }
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
std::vector<llama_token_data> cur;
const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
// sort tokens by logits
std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
// apply softmax
float max_l = cur[0].logit;
float cum_sum = 0.0f;
for (size_t i = 0; i < cur.size(); ++i) {
float p = expf(cur[i].logit - max_l);
cur[i].p = p;
cum_sum += p;
}
for (size_t i = 0; i < cur.size(); ++i) {
cur[i].p /= cum_sum;
}
return cur;
}