mirror of https://github.com/ggerganov/llama.cpp.git
synced 2025-02-09 09:48:16 +01:00
server : fill usage info in embeddings and rerank responses (#10852)
* server : fill usage info in embeddings response

* server : fill usage info in reranking response
This commit is contained in:
parent 382bc7f2e8
commit 05c3a444b8
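In short, the change propagates each slot's prompt token count into the OpenAI-compatible usage block, so /embeddings and /rerank responses report real prompt_tokens/total_tokens values instead of the previous hard-coded zeros. A minimal client-side sketch of reading the new fields, assuming a llama-server instance listening on localhost:8080 (host, port, and input text are illustrative, not part of this commit):

    import requests

    # Query the OpenAI-compatible embeddings endpoint of a running llama-server.
    res = requests.post(
        "http://localhost:8080/embeddings",
        json={"input": "I believe the meaning of life is"},
    )
    res.raise_for_status()
    body = res.json()

    # With this commit, usage is summed from each result's tokens_evaluated
    # field instead of being hard-coded to zero.
    print(body["usage"]["prompt_tokens"])
    print(body["usage"]["total_tokens"])  # equals prompt_tokens for embeddings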
examples/server/server.cpp

@@ -719,14 +719,17 @@ struct server_task_result_embd : server_task_result {
     int index = 0;
     std::vector<float> embedding;
 
+    int32_t n_tokens;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
         return json {
             {"index", index},
             {"embedding", embedding},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };
@@ -735,14 +738,17 @@ struct server_task_result_rerank : server_task_result {
     int index = 0;
     float score = -1e6;
 
+    int32_t n_tokens;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
         return json {
             {"index", index},
             {"score", score},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };
@@ -1995,6 +2001,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_embd>();
         res->id = slot.id_task;
         res->index = slot.index;
+        res->n_tokens = slot.n_prompt_tokens;
 
         const int n_embd = llama_n_embd(model);
 
@@ -2030,6 +2037,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id = slot.id_task;
         res->index = slot.index;
+        res->n_tokens = slot.n_prompt_tokens;
 
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
examples/server/tests/unit/test_embedding.py

@@ -97,3 +97,33 @@ def test_same_prompt_give_same_result():
         vi = res.body['data'][i]['embedding']
         for x, y in zip(v0, vi):
             assert abs(x - y) < EPSILON
+
+
+@pytest.mark.parametrize(
+    "content,n_tokens",
+    [
+        ("I believe the meaning of life is", 7),
+        ("This is a test", 4),
+    ]
+)
+def test_embedding_usage_single(content, n_tokens):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"input": content})
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == n_tokens
+
+
+def test_embedding_usage_multiple():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+        ],
+    })
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == 2 * 7
examples/server/tests/unit/test_rerank.py

@@ -53,3 +53,26 @@ def test_invalid_rerank_req(documents):
     })
     assert res.status_code == 400
     assert "error" in res.body
+
+
+@pytest.mark.parametrize(
+    "query,doc1,doc2,n_tokens",
+    [
+        ("Machine learning is", "A machine", "Learning is", 19),
+        ("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
+    ]
+)
+def test_rerank_usage(query, doc1, doc2, n_tokens):
+    global server
+    server.start()
+
+    res = server.make_request("POST", "/rerank", data={
+        "query": query,
+        "documents": [
+            doc1,
+            doc2,
+        ]
+    })
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == n_tokens
examples/server/utils.hpp

@@ -560,6 +560,7 @@ static json oaicompat_completion_params_parse(
 
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & elem : embeddings) {
         data.push_back(json{
@@ -567,14 +568,16 @@ static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
             {"index", i++},
             {"object", "embedding"}
         });
+
+        n_tokens += json_value(elem, "tokens_evaluated", 0);
     }
 
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json { // TODO: fill
-            {"prompt_tokens", 0},
-            {"total_tokens", 0}
+        {"usage", json {
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
         }},
         {"data", data}
     };
@@ -584,20 +587,23 @@ static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
 
 static json format_response_rerank(const json & request, const json & ranks) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & rank : ranks) {
         data.push_back(json{
             {"index", i++},
             {"relevance_score", json_value(rank, "score", 0.0)},
         });
+
+        n_tokens += json_value(rank, "tokens_evaluated", 0);
     }
 
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json { // TODO: fill
-            {"prompt_tokens", 0},
-            {"total_tokens", 0}
+        {"usage", json {
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
         }},
         {"results", data}
     };
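As a companion to the tests above, a hedged sketch of exercising the rerank endpoint from a plain HTTP client (host, port, and documents are illustrative; the request shape follows test_rerank_usage):

    import requests

    # POST a query plus candidate documents to /rerank, as test_rerank_usage does.
    res = requests.post(
        "http://localhost:8080/rerank",
        json={
            "query": "Machine learning is",
            "documents": ["A machine", "Learning is"],
        },
    )
    res.raise_for_status()
    body = res.json()

    # usage now sums tokens_evaluated across all query/document pairs.
    print(body["usage"]["prompt_tokens"], body["usage"]["total_tokens"])
    for result in body["results"]:
        print(result["index"], result["relevance_score"])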