mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
server : fix logprobs, make it OAI-compatible (#10783)
* server : fix logprobs, make it openai-compatible * update docs * add std::log * return pre-sampling p * sort before apply softmax * add comment * fix test * set p for sampled token * update docs * add --multi-token-probs * update docs * add `post_sampling_probs` option * update docs [no ci] * remove --multi-token-probs * "top_probs" with "post_sampling_probs" * resolve review comments * rename struct token_prob to prob_info * correct comment placement * fix setting prob for sampled token
This commit is contained in:
parent
a3c33b1dce
commit
57bb2c40cd
@ -343,6 +343,10 @@ node index.js
|
||||
|
||||
### POST `/completion`: Given a `prompt`, it returns the predicted completion.
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> This endpoint is **not** OAI-compatible
|
||||
|
||||
*Options:*
|
||||
|
||||
`prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, if `cache_prompt` is `true`, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. A `BOS` token is inserted at the start, if all of the following conditions are true:
|
||||
@ -444,38 +448,68 @@ These words will not be included in the completion, so make sure to add them to
|
||||
|
||||
`timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
|
||||
|
||||
`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
|
||||
|
||||
**Response format**
|
||||
|
||||
- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
|
||||
|
||||
- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "<the token generated by the model>",
|
||||
- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements:
|
||||
```json
|
||||
{
|
||||
"content": "<the generated completion text>",
|
||||
"tokens": [ generated token ids if requested ],
|
||||
...
|
||||
"probs": [
|
||||
{
|
||||
"prob": float,
|
||||
"tok_str": "<most likely token>"
|
||||
"id": <token id>,
|
||||
"logprob": float,
|
||||
"token": "<most likely token>",
|
||||
"bytes": [int, int, ...],
|
||||
"top_logprobs": [
|
||||
{
|
||||
"id": <token id>,
|
||||
"logprob": float,
|
||||
"token": "<token text>",
|
||||
"bytes": [int, int, ...],
|
||||
},
|
||||
{
|
||||
"prob": float,
|
||||
"tok_str": "<second most likely token>"
|
||||
"id": <token id>,
|
||||
"logprob": float,
|
||||
"token": "<token text>",
|
||||
"bytes": [int, int, ...],
|
||||
},
|
||||
...
|
||||
]
|
||||
},
|
||||
```
|
||||
|
||||
Notice that each `probs` is an array of length `n_probs`.
|
||||
},
|
||||
{
|
||||
"id": <token id>,
|
||||
"logprob": float,
|
||||
"token": "<most likely token>",
|
||||
"bytes": [int, int, ...],
|
||||
"top_logprobs": [
|
||||
...
|
||||
]
|
||||
},
|
||||
...
|
||||
]
|
||||
},
|
||||
```
|
||||
Please note that if `post_sampling_probs` is set to `true`:
|
||||
- `logprob` will be replaced with `prob`, with the value between 0.0 and 1.0
|
||||
- `top_logprobs` will be replaced with `top_probs`. Each element contains:
|
||||
- `id`: token ID
|
||||
- `token`: token in string
|
||||
- `bytes`: token in bytes
|
||||
- `prob`: token probability, with the value between 0.0 and 1.0
|
||||
- Number of elements in `top_probs` may be less than `n_probs`
|
||||
|
||||
- `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
|
||||
- `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.
|
||||
- `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
|
||||
- `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
|
||||
- `model`: The path to the model loaded with `-m`
|
||||
- `prompt`: The provided `prompt`
|
||||
- `model`: The model alias (for model path, please use `/props` endpoint)
|
||||
- `prompt`: The processed `prompt` (special tokens may be added)
|
||||
- `stop_type`: Indicating whether the completion has stopped. Possible values are:
|
||||
- `none`: Generating (not stopped)
|
||||
- `eos`: Stopped because it encountered the EOS token
|
||||
|
@ -93,6 +93,7 @@ struct slot_params {
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
bool timings_per_token = false;
|
||||
bool post_sampling_probs = false;
|
||||
bool ignore_eos = false;
|
||||
|
||||
struct common_params_sampling sampling;
|
||||
@ -151,6 +152,7 @@ struct slot_params {
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
{"speculative.p_min", speculative.p_min},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
};
|
||||
}
|
||||
};
|
||||
@ -231,6 +233,7 @@ struct server_task {
|
||||
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
|
||||
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
|
||||
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
|
||||
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
|
||||
|
||||
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
|
||||
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
|
||||
@ -436,36 +439,67 @@ inline std::string stop_type_to_str(stop_type type) {
|
||||
|
||||
struct completion_token_output {
|
||||
llama_token tok;
|
||||
float prob;
|
||||
std::string text_to_send;
|
||||
struct token_prob {
|
||||
struct prob_info {
|
||||
llama_token tok;
|
||||
std::string tok_str;
|
||||
std::string txt;
|
||||
float prob;
|
||||
};
|
||||
std::vector<token_prob> probs;
|
||||
std::vector<prob_info> probs;
|
||||
|
||||
json to_json() const {
|
||||
json to_json(bool post_sampling_probs) const {
|
||||
json probs_for_token = json::array();
|
||||
for (const auto & p : probs) {
|
||||
std::string txt(p.txt);
|
||||
txt.resize(validate_utf8(txt));
|
||||
probs_for_token.push_back(json {
|
||||
{"tok_str", p.tok_str},
|
||||
{"prob", p.prob},
|
||||
{"id", p.tok},
|
||||
{"token", txt},
|
||||
{"bytes", str_to_bytes(p.txt)},
|
||||
{
|
||||
post_sampling_probs ? "prob" : "logprob",
|
||||
post_sampling_probs ? p.prob : logarithm(p.prob)
|
||||
},
|
||||
});
|
||||
}
|
||||
return probs_for_token;
|
||||
}
|
||||
|
||||
static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
|
||||
static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
|
||||
json out = json::array();
|
||||
for (const auto & prob : probs) {
|
||||
const std::string tok_str = prob.text_to_send;
|
||||
for (const auto & p : probs) {
|
||||
std::string txt(p.text_to_send);
|
||||
txt.resize(validate_utf8(txt));
|
||||
out.push_back(json {
|
||||
{"content", tok_str},
|
||||
{"probs", prob.to_json()},
|
||||
{"id", p.tok},
|
||||
{"token", txt},
|
||||
{"bytes", str_to_bytes(p.text_to_send)},
|
||||
{
|
||||
post_sampling_probs ? "prob" : "logprob",
|
||||
post_sampling_probs ? p.prob : logarithm(p.prob)
|
||||
},
|
||||
{
|
||||
post_sampling_probs ? "top_probs" : "top_logprobs",
|
||||
p.to_json(post_sampling_probs)
|
||||
},
|
||||
});
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static float logarithm(float x) {
|
||||
// nlohmann::json converts -inf to null, so we need to prevent that
|
||||
return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
|
||||
}
|
||||
|
||||
static std::vector<unsigned char> str_to_bytes(const std::string & str) {
|
||||
std::vector<unsigned char> bytes;
|
||||
for (unsigned char c : str) {
|
||||
bytes.push_back(c);
|
||||
}
|
||||
return bytes;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_cmpl_final : server_task_result {
|
||||
@ -486,6 +520,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
std::string stopping_word;
|
||||
stop_type stop = STOP_TYPE_NONE;
|
||||
|
||||
bool post_sampling_probs;
|
||||
std::vector<completion_token_output> probs_output;
|
||||
|
||||
slot_params generation_params;
|
||||
@ -530,8 +565,8 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
{"tokens_cached", n_tokens_cached},
|
||||
{"timings", timings.to_json()},
|
||||
};
|
||||
if (!probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
|
||||
if (!stream && !probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@ -542,19 +577,25 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
|
||||
json choices = json::array({json{
|
||||
json choice = json{
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", json {
|
||||
{"content", content},
|
||||
{"role", "assistant"}
|
||||
}
|
||||
}}});
|
||||
}};
|
||||
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
choice["logprobs"] = json{
|
||||
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
|
||||
json res = json {
|
||||
{"choices", choices},
|
||||
{"choices", json::array({choice})},
|
||||
{"created", t},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion"},
|
||||
@ -584,12 +625,14 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
|
||||
json choices = json::array({json{{"finish_reason", finish_reason},
|
||||
json choice = json{
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}}});
|
||||
{"delta", json::object()}
|
||||
};
|
||||
|
||||
json ret = json {
|
||||
{"choices", choices},
|
||||
{"choices", json::array({choice})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
@ -618,7 +661,8 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
int32_t n_decoded;
|
||||
int32_t n_prompt_tokens;
|
||||
|
||||
std::vector<completion_token_output> probs_output;
|
||||
bool post_sampling_probs;
|
||||
completion_token_output prob_output;
|
||||
result_timings timings;
|
||||
|
||||
// OAI-compat fields
|
||||
@ -655,8 +699,8 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
if (timings.prompt_n > 0) {
|
||||
res.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
if (!probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
|
||||
if (!prob_output.probs.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@ -708,6 +752,14 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
}});
|
||||
}
|
||||
|
||||
GGML_ASSERT(choices.size() >= 1);
|
||||
|
||||
if (prob_output.probs.size() > 0) {
|
||||
choices[0]["logprobs"] = json{
|
||||
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
json ret = json {
|
||||
{"choices", choices},
|
||||
{"created", t},
|
||||
@ -1001,7 +1053,6 @@ struct server_slot {
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
size_t n_sent_token_probs = 0;
|
||||
|
||||
int64_t t_start_process_prompt;
|
||||
int64_t t_start_generation;
|
||||
@ -1023,7 +1074,6 @@ struct server_slot {
|
||||
stopping_word = "";
|
||||
n_past = 0;
|
||||
n_sent_text = 0;
|
||||
n_sent_token_probs = 0;
|
||||
task_type = SERVER_TASK_TYPE_COMPLETION;
|
||||
|
||||
generated_tokens.clear();
|
||||
@ -1764,7 +1814,7 @@ struct server_context {
|
||||
|
||||
bool process_token(completion_token_output & result, server_slot & slot) {
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
|
||||
const std::string token_str = result.text_to_send;
|
||||
slot.sampled = result.tok;
|
||||
|
||||
slot.generated_text += token_str;
|
||||
@ -1774,26 +1824,7 @@ struct server_context {
|
||||
slot.has_next_token = true;
|
||||
|
||||
// check if there is incomplete UTF-8 character at the end
|
||||
bool incomplete = false;
|
||||
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
|
||||
unsigned char c = slot.generated_text[slot.generated_text.size() - i];
|
||||
if ((c & 0xC0) == 0x80) {
|
||||
// continuation byte: 10xxxxxx
|
||||
continue;
|
||||
}
|
||||
if ((c & 0xE0) == 0xC0) {
|
||||
// 2-byte character: 110xxxxx ...
|
||||
incomplete = i < 2;
|
||||
} else if ((c & 0xF0) == 0xE0) {
|
||||
// 3-byte character: 1110xxxx ...
|
||||
incomplete = i < 3;
|
||||
} else if ((c & 0xF8) == 0xF0) {
|
||||
// 4-byte character: 11110xxx ...
|
||||
incomplete = i < 4;
|
||||
}
|
||||
// else 1-byte character or invalid byte
|
||||
break;
|
||||
}
|
||||
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
|
||||
|
||||
// search stop word and delete it
|
||||
if (!incomplete) {
|
||||
@ -1923,6 +1954,55 @@ struct server_context {
|
||||
return slot.has_next_token; // continue
|
||||
}
|
||||
|
||||
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
|
||||
size_t n_probs = slot.params.sampling.n_probs;
|
||||
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
if (post_sampling) {
|
||||
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
|
||||
const size_t max_probs = cur_p->size;
|
||||
|
||||
// set probability for sampled token
|
||||
for (size_t i = 0; i < max_probs; i++) {
|
||||
if (cur_p->data[i].id == result.tok) {
|
||||
result.prob = cur_p->data[i].p;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// set probability for top n_probs tokens
|
||||
result.probs.reserve(max_probs);
|
||||
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
||||
result.probs.push_back({
|
||||
cur_p->data[i].id,
|
||||
common_detokenize(ctx, {cur_p->data[i].id}, special),
|
||||
cur_p->data[i].p
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// TODO: optimize this with min-p optimization
|
||||
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
||||
|
||||
// set probability for sampled token
|
||||
for (size_t i = 0; i < n_vocab; i++) {
|
||||
// set probability for sampled token
|
||||
if (cur[i].id == result.tok) {
|
||||
result.prob = cur[i].p;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// set probability for top n_probs tokens
|
||||
result.probs.reserve(n_probs);
|
||||
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
||||
result.probs.push_back({
|
||||
cur[i].id,
|
||||
common_detokenize(ctx, {cur[i].id}, special),
|
||||
cur[i].p
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
||||
send_error(task.id, error, type);
|
||||
}
|
||||
@ -1952,6 +2032,7 @@ struct server_context {
|
||||
|
||||
res->n_decoded = slot.n_decoded;
|
||||
res->n_prompt_tokens = slot.n_prompt_tokens;
|
||||
res->post_sampling_probs = slot.params.post_sampling_probs;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
@ -1961,17 +2042,7 @@ struct server_context {
|
||||
|
||||
// populate res.probs_output
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
||||
|
||||
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
||||
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
|
||||
|
||||
std::vector<completion_token_output> probs_output;
|
||||
if (probs_pos < probs_stop_pos) {
|
||||
res->probs_output = std::vector<completion_token_output>(
|
||||
slot.generated_token_probs.begin() + probs_pos,
|
||||
slot.generated_token_probs.begin() + probs_stop_pos);
|
||||
}
|
||||
res->prob_output = tkn; // copy the token probs
|
||||
}
|
||||
|
||||
// populate timings if this is final response or timings_per_token is enabled
|
||||
@ -2000,6 +2071,7 @@ struct server_context {
|
||||
res->has_new_line = slot.has_new_line;
|
||||
res->stopping_word = slot.stopping_word;
|
||||
res->stop = slot.stop;
|
||||
res->post_sampling_probs = slot.params.post_sampling_probs;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->stream = slot.params.stream;
|
||||
@ -2796,7 +2868,9 @@ struct server_context {
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
|
||||
const int tok_idx = slot.i_batch - i;
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
|
||||
|
||||
slot.i_batch = -1;
|
||||
|
||||
@ -2816,16 +2890,11 @@ struct server_context {
|
||||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
|
||||
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
|
||||
|
||||
for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
|
||||
auto tok_id = cur_p->data[i].id;
|
||||
result.probs.push_back({
|
||||
tok_id,
|
||||
tokens_to_output_formatted_string(ctx, tok_id),
|
||||
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
|
||||
});
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
|
||||
}
|
||||
|
||||
if (!process_token(result, slot)) {
|
||||
@ -2910,6 +2979,10 @@ struct server_context {
|
||||
completion_token_output result;
|
||||
|
||||
result.tok = ids[i];
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
|
||||
result.prob = 1.0f; // set later
|
||||
|
||||
// TODO: set result.probs
|
||||
|
||||
if (!process_token(result, slot)) {
|
||||
// release slot because of stop condition
|
||||
|
@ -92,7 +92,6 @@ def test_chat_completion_with_openai_library():
|
||||
seed=42,
|
||||
temperature=0.8,
|
||||
)
|
||||
print(res)
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert res.choices[0].message.content is not None
|
||||
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
||||
@ -163,3 +162,64 @@ def test_chat_completion_with_timings_per_token():
|
||||
assert "predicted_per_second" in data["timings"]
|
||||
assert "predicted_n" in data["timings"]
|
||||
assert data["timings"]["predicted_n"] <= 10
|
||||
|
||||
|
||||
def test_logprobs():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
max_tokens=5,
|
||||
logprobs=True,
|
||||
top_logprobs=10,
|
||||
)
|
||||
output_text = res.choices[0].message.content
|
||||
aggregated_text = ''
|
||||
assert res.choices[0].logprobs is not None
|
||||
assert res.choices[0].logprobs.content is not None
|
||||
for token in res.choices[0].logprobs.content:
|
||||
aggregated_text += token.token
|
||||
assert token.logprob <= 0.0
|
||||
assert token.bytes is not None
|
||||
assert len(token.top_logprobs) > 0
|
||||
assert aggregated_text == output_text
|
||||
|
||||
|
||||
def test_logprobs_stream():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
max_tokens=5,
|
||||
logprobs=True,
|
||||
top_logprobs=10,
|
||||
stream=True,
|
||||
)
|
||||
output_text = ''
|
||||
aggregated_text = ''
|
||||
for data in res:
|
||||
choice = data.choices[0]
|
||||
if choice.finish_reason is None:
|
||||
if choice.delta.content:
|
||||
output_text += choice.delta.content
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.content is not None
|
||||
for token in choice.logprobs.content:
|
||||
aggregated_text += token.token
|
||||
assert token.logprob <= 0.0
|
||||
assert token.bytes is not None
|
||||
assert token.top_logprobs is not None
|
||||
assert len(token.top_logprobs) > 0
|
||||
assert aggregated_text == output_text
|
||||
|
@ -270,9 +270,68 @@ def test_n_probs():
|
||||
assert "completion_probabilities" in res.body
|
||||
assert len(res.body["completion_probabilities"]) == 5
|
||||
for tok in res.body["completion_probabilities"]:
|
||||
assert "probs" in tok
|
||||
assert len(tok["probs"]) == 10
|
||||
for prob in tok["probs"]:
|
||||
assert "prob" in prob
|
||||
assert "tok_str" in prob
|
||||
assert 0.0 <= prob["prob"] <= 1.0
|
||||
assert "id" in tok and tok["id"] > 0
|
||||
assert "token" in tok and type(tok["token"]) == str
|
||||
assert "logprob" in tok and tok["logprob"] <= 0.0
|
||||
assert "bytes" in tok and type(tok["bytes"]) == list
|
||||
assert len(tok["top_logprobs"]) == 10
|
||||
for prob in tok["top_logprobs"]:
|
||||
assert "id" in prob and prob["id"] > 0
|
||||
assert "token" in prob and type(prob["token"]) == str
|
||||
assert "logprob" in prob and prob["logprob"] <= 0.0
|
||||
assert "bytes" in prob and type(prob["bytes"]) == list
|
||||
|
||||
|
||||
def test_n_probs_stream():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"n_probs": 10,
|
||||
"temperature": 0.0,
|
||||
"n_predict": 5,
|
||||
"stream": True,
|
||||
})
|
||||
for data in res:
|
||||
if data["stop"] == False:
|
||||
assert "completion_probabilities" in data
|
||||
assert len(data["completion_probabilities"]) == 1
|
||||
for tok in data["completion_probabilities"]:
|
||||
assert "id" in tok and tok["id"] > 0
|
||||
assert "token" in tok and type(tok["token"]) == str
|
||||
assert "logprob" in tok and tok["logprob"] <= 0.0
|
||||
assert "bytes" in tok and type(tok["bytes"]) == list
|
||||
assert len(tok["top_logprobs"]) == 10
|
||||
for prob in tok["top_logprobs"]:
|
||||
assert "id" in prob and prob["id"] > 0
|
||||
assert "token" in prob and type(prob["token"]) == str
|
||||
assert "logprob" in prob and prob["logprob"] <= 0.0
|
||||
assert "bytes" in prob and type(prob["bytes"]) == list
|
||||
|
||||
|
||||
def test_n_probs_post_sampling():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"n_probs": 10,
|
||||
"temperature": 0.0,
|
||||
"n_predict": 5,
|
||||
"post_sampling_probs": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "completion_probabilities" in res.body
|
||||
assert len(res.body["completion_probabilities"]) == 5
|
||||
for tok in res.body["completion_probabilities"]:
|
||||
assert "id" in tok and tok["id"] > 0
|
||||
assert "token" in tok and type(tok["token"]) == str
|
||||
assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
|
||||
assert "bytes" in tok and type(tok["bytes"]) == list
|
||||
assert len(tok["top_probs"]) == 10
|
||||
for prob in tok["top_probs"]:
|
||||
assert "id" in prob and prob["id"] > 0
|
||||
assert "token" in prob and type(prob["token"]) == str
|
||||
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
|
||||
assert "bytes" in prob and type(prob["bytes"]) == list
|
||||
# because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
|
||||
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
|
||||
|
@ -50,6 +50,8 @@ def test_embedding_multiple():
|
||||
@pytest.mark.parametrize(
|
||||
"input,is_multi_prompt",
|
||||
[
|
||||
# do not crash on empty input
|
||||
("", False),
|
||||
# single prompt
|
||||
("string", False),
|
||||
([12, 34, 56], False),
|
||||
@ -103,6 +105,7 @@ def test_embedding_pooling_none_oai():
|
||||
|
||||
# /v1/embeddings does not support pooling type 'none'
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
|
||||
|
||||
def test_embedding_openai_library_single():
|
||||
|
@ -171,6 +171,36 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
|
||||
return result;
|
||||
}
|
||||
|
||||
// return the last index of character that can form a valid string
|
||||
// if the last character is potentially cut in half, return the index before the cut
|
||||
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
|
||||
static size_t validate_utf8(const std::string& text) {
|
||||
size_t len = text.size();
|
||||
if (len == 0) return 0;
|
||||
|
||||
// Check the last few bytes to see if a multi-byte character is cut off
|
||||
for (size_t i = 1; i <= 4 && i <= len; ++i) {
|
||||
unsigned char c = text[len - i];
|
||||
// Check for start of a multi-byte sequence from the end
|
||||
if ((c & 0xE0) == 0xC0) {
|
||||
// 2-byte character start: 110xxxxx
|
||||
// Needs at least 2 bytes
|
||||
if (i < 2) return len - i;
|
||||
} else if ((c & 0xF0) == 0xE0) {
|
||||
// 3-byte character start: 1110xxxx
|
||||
// Needs at least 3 bytes
|
||||
if (i < 3) return len - i;
|
||||
} else if ((c & 0xF8) == 0xF0) {
|
||||
// 4-byte character start: 11110xxx
|
||||
// Needs at least 4 bytes
|
||||
if (i < 4) return len - i;
|
||||
}
|
||||
}
|
||||
|
||||
// If no cut-off multi-byte character is found, return full length
|
||||
return len;
|
||||
}
|
||||
|
||||
//
|
||||
// template utils
|
||||
//
|
||||
@ -671,3 +701,33 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
|
||||
static std::string safe_json_to_str(json data) {
|
||||
return data.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
}
|
||||
|
||||
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
|
||||
std::vector<llama_token_data> cur;
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
cur.resize(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
}
|
||||
|
||||
// sort tokens by logits
|
||||
std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
|
||||
// apply softmax
|
||||
float max_l = cur[0].logit;
|
||||
float cum_sum = 0.0f;
|
||||
for (size_t i = 0; i < cur.size(); ++i) {
|
||||
float p = expf(cur[i].logit - max_l);
|
||||
cur[i].p = p;
|
||||
cum_sum += p;
|
||||
}
|
||||
for (size_t i = 0; i < cur.size(); ++i) {
|
||||
cur[i].p /= cum_sum;
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user