server : output embeddings for all tokens when pooling = none (#10861)

* server : add "tokens" output

ggml-ci

* server : output embeddings for all tokens when pooling = none

ggml-ci

* server : update readme [no ci]

* server : fix spacing [no ci]

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>

* server : be explicit about the pooling type in the tests

ggml-ci

* server : update /embeddings and /v1/embeddings endpoints

ggml-ci

* server : do not normalize embeddings when there is no pooling

ggml-ci

* server : update readme

ggml-ci

* server : fixes

* tests : update server tests

ggml-ci

* server : update readme [no ci]

* server : remove rebase artifact

---------

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
Georgi Gerganov 2024-12-18 13:01:41 +02:00 committed by GitHub
parent 0e70ba686e
commit 152610eda9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 158 additions and 37 deletions

View File

@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
break;
case 0: // max absolute
for (int i = 0; i < n; i++) {
if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
if (sum < std::abs(inp[i])) {
sum = std::abs(inp[i]);
}
}
sum /= 32760.0; // make an int16 range
break;

View File

@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
// Embedding utils
//
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
// TODO: repace embd_norm with an enum
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);

View File

@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
}
std::vector<float> emb_norm(emb_unorm.size());
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
result.push_back(emb_norm);
#ifdef GRIT_DEBUG

View File

@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
float * out = output + batch.seq_id[i][0] * n_embd;
common_embd_normalize(embd, out, n_embd);
common_embd_normalize(embd, out, n_embd, 2);
}
}

View File

@ -763,6 +763,8 @@ curl http://localhost:8080/v1/chat/completions \
### POST `/v1/embeddings`: OpenAI-compatible embeddings API
This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm.
*Options:*
See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
@ -795,6 +797,46 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
}'
```
### POST `/embeddings`: non-OpenAI-compatible embeddings API
This endpoint supports all poolings, including `--pooling none`. When the pooling is `none`, the responses will contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using Euclidian norm.
Note that the response format of this endpoint is different from `/v1/embeddings`.
*Options:*
Same as the `/v1/embeddings` endpoint.
*Examples:*
Same as the `/v1/embeddings` endpoint.
**Response format**
```json
[
{
"index": 0,
"embedding": [
[ ... embeddings for token 0 ... ],
[ ... embeddings for token 1 ... ],
[ ... ]
[ ... embeddings for token N-1 ... ],
]
},
...
{
"index": P,
"embedding": [
[ ... embeddings for token 0 ... ],
[ ... embeddings for token 1 ... ],
[ ... ]
[ ... embeddings for token N-1 ... ],
]
}
]
```
### GET `/slots`: Returns the current slots processing state
> [!WARNING]

View File

@ -726,18 +726,32 @@ struct server_task_result_cmpl_partial : server_task_result {
struct server_task_result_embd : server_task_result {
int index = 0;
std::vector<float> embedding;
std::vector<std::vector<float>> embedding;
int32_t n_tokens;
// OAI-compat fields
bool oaicompat = false;
virtual int get_index() override {
return index;
}
virtual json to_json() override {
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
}
json to_json_non_oaicompat() {
return json {
{"index", index},
{"embedding", embedding},
};
}
json to_json_oaicompat() {
return json {
{"index", index},
{"embedding", embedding},
{"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
};
}
@ -2017,9 +2031,10 @@ struct server_context {
void send_embedding(const server_slot & slot, const llama_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
res->oaicompat = slot.params.oaicompat;
const int n_embd = llama_n_embd(model);
@ -2038,12 +2053,18 @@ struct server_context {
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
res->embedding = std::vector<float>(n_embd, 0.0f);
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
}
common_embd_normalize(embd, embd_res.data(), n_embd);
res->embedding = embd_res;
// normalize only when there is pooling
// TODO: configurable
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
res->embedding.push_back(embd_res);
} else {
res->embedding.push_back({ embd, embd + n_embd });
}
}
SLT_DBG(slot, "%s", "sending embeddings\n");
@ -2657,7 +2678,10 @@ struct server_context {
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
// without pooling, we want to output the embeddings for all the tokens in the batch
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@ -3665,14 +3689,17 @@ int main(int argc, char ** argv) {
res_ok(res, data);
};
const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
const json body = json::parse(req.body);
bool oaicompat = false;
if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
return;
}
// for the shape of input/content, see tokenize_input_prompts()
json prompt;
if (body.contains("input")) {
oaicompat = true;
if (body.count("input") != 0) {
prompt = body.at("input");
} else if (body.contains("content")) {
oaicompat = false;
@ -3697,10 +3724,15 @@ int main(int argc, char ** argv) {
{
std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);
// OAI-compat
task.params.oaicompat = oaicompat;
tasks.push_back(task);
}
@ -3728,12 +3760,18 @@ int main(int argc, char ** argv) {
}
// write JSON response
json root = oaicompat
? format_embeddings_response_oaicompat(body, responses)
: responses.size() == 1 ? responses[0] : json(responses);
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
res_ok(res, root);
};
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
handle_embeddings_impl(req, res, false);
};
const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
handle_embeddings_impl(req, res, true);
};
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
@ -3907,7 +3945,7 @@ int main(int argc, char ** argv) {
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings_oai);
svr->Post("/rerank", handle_rerank);
svr->Post("/reranking", handle_rerank);
svr->Post("/v1/rerank", handle_rerank);

View File

@ -14,8 +14,9 @@ def create_server():
def test_embedding_single():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
@ -29,8 +30,9 @@ def test_embedding_single():
def test_embedding_multiple():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
@ -46,7 +48,7 @@ def test_embedding_multiple():
@pytest.mark.parametrize(
"content,is_multi_prompt",
"input,is_multi_prompt",
[
# single prompt
("string", False),
@ -59,25 +61,55 @@ def test_embedding_multiple():
([[12, 34, 56], [12, "string", 34, 56]], True),
]
)
def test_embedding_mixed_input(content, is_multi_prompt: bool):
def test_embedding_mixed_input(input, is_multi_prompt: bool):
global server
server.start()
res = server.make_request("POST", "/embeddings", data={"content": content})
res = server.make_request("POST", "/v1/embeddings", data={"input": input})
assert res.status_code == 200
data = res.body['data']
if is_multi_prompt:
assert len(res.body) == len(content)
for d in res.body:
assert len(data) == len(input)
for d in data:
assert 'embedding' in d
assert len(d['embedding']) > 1
else:
assert 'embedding' in res.body
assert len(res.body['embedding']) > 1
assert 'embedding' in data[0]
assert len(data[0]['embedding']) > 1
def test_embedding_pooling_none():
global server
server.pooling = 'none'
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "hello hello hello",
})
assert res.status_code == 200
assert 'embedding' in res.body[0]
assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
# make sure embedding vector is not normalized
for x in res.body[0]['embedding']:
assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
def test_embedding_pooling_none_oai():
global server
server.pooling = 'none'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "hello hello hello",
})
# /v1/embeddings does not support pooling type 'none'
assert res.status_code == 400
def test_embedding_openai_library_single():
global server
server.pooling = 'last'
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
assert len(res.data) == 1
assert len(res.data[0].embedding) > 1
@ -85,8 +117,9 @@ def test_embedding_openai_library_single():
def test_embedding_openai_library_multiple():
global server
server.pooling = 'last'
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input=[
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
@ -100,8 +133,9 @@ def test_embedding_openai_library_multiple():
def test_embedding_error_prompt_too_long():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": "This is a test " * 512,
})
assert res.status_code != 200
@ -109,8 +143,9 @@ def test_embedding_error_prompt_too_long():
def test_same_prompt_give_same_result():
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
@ -138,7 +173,7 @@ def test_same_prompt_give_same_result():
def test_embedding_usage_single(content, n_tokens):
global server
server.start()
res = server.make_request("POST", "/embeddings", data={"input": content})
res = server.make_request("POST", "/v1/embeddings", data={"input": content})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens
@ -147,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens):
def test_embedding_usage_multiple():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",

View File

@ -65,6 +65,7 @@ class ServerProcess:
server_reranking: bool | None = False
server_metrics: bool | None = False
server_slots: bool | None = False
pooling: str | None = None
draft: int | None = None
api_key: str | None = None
response_format: str | None = None
@ -132,6 +133,8 @@ class ServerProcess:
server_args.append("--metrics")
if self.server_slots:
server_args.append("--slots")
if self.pooling:
server_args.extend(["--pooling", self.pooling])
if self.model_alias:
server_args.extend(["--alias", self.model_alias])
if self.n_ctx: