diff --git a/common/common.cpp b/common/common.cpp index c0c98232e..05d3ba766 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) break; case 0: // max absolute for (int i = 0; i < n; i++) { - if (sum < std::abs(inp[i])) sum = std::abs(inp[i]); + if (sum < std::abs(inp[i])) { + sum = std::abs(inp[i]); + } } sum /= 32760.0; // make an int16 range break; diff --git a/common/common.h b/common/common.h index 5f556c24d..ec0e49f6f 100644 --- a/common/common.h +++ b/common/common.h @@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si // Embedding utils // -void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2); +// TODO: repace embd_norm with an enum +void common_embd_normalize(const float * inp, float * out, int n, int embd_norm); float common_embd_similarity_cos(const float * embd1, const float * embd2, int n); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 6e42fa073..18a945b33 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -75,7 +75,7 @@ static std::vector> encode(llama_context * ctx, const std::ve } std::vector emb_norm(emb_unorm.size()); - common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd); + common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2); result.push_back(emb_norm); #ifdef GRIT_DEBUG diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 23ff4db27..a5c6fe7e5 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } float * out = output + batch.seq_id[i][0] * n_embd; - common_embd_normalize(embd, out, n_embd); + common_embd_normalize(embd, out, n_embd, 2); } } diff --git a/examples/server/README.md b/examples/server/README.md index ecd24c899..d006a8d37 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -763,6 +763,8 @@ curl http://localhost:8080/v1/chat/completions \ ### POST `/v1/embeddings`: OpenAI-compatible embeddings API +This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm. + *Options:* See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings). @@ -795,6 +797,46 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r }' ``` +### POST `/embeddings`: non-OpenAI-compatible embeddings API + +This endpoint supports all poolings, including `--pooling none`. When the pooling is `none`, the responses will contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using Euclidian norm. + +Note that the response format of this endpoint is different from `/v1/embeddings`. + +*Options:* + +Same as the `/v1/embeddings` endpoint. + +*Examples:* + +Same as the `/v1/embeddings` endpoint. + +**Response format** + +```json +[ + { + "index": 0, + "embedding": [ + [ ... embeddings for token 0 ... ], + [ ... embeddings for token 1 ... ], + [ ... ] + [ ... embeddings for token N-1 ... ], + ] + }, + ... + { + "index": P, + "embedding": [ + [ ... embeddings for token 0 ... ], + [ ... embeddings for token 1 ... ], + [ ... ] + [ ... embeddings for token N-1 ... ], + ] + } +] +``` + ### GET `/slots`: Returns the current slots processing state > [!WARNING] diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 40aac33f0..5ed4e8d27 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -726,18 +726,32 @@ struct server_task_result_cmpl_partial : server_task_result { struct server_task_result_embd : server_task_result { int index = 0; - std::vector embedding; + std::vector> embedding; int32_t n_tokens; + // OAI-compat fields + bool oaicompat = false; + virtual int get_index() override { return index; } virtual json to_json() override { + return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat(); + } + + json to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; + } + + json to_json_oaicompat() { return json { {"index", index}, - {"embedding", embedding}, + {"embedding", embedding[0]}, {"tokens_evaluated", n_tokens}, }; } @@ -2017,9 +2031,10 @@ struct server_context { void send_embedding(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + res->oaicompat = slot.params.oaicompat; const int n_embd = llama_n_embd(model); @@ -2038,12 +2053,18 @@ struct server_context { if (embd == NULL) { SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - res->embedding = std::vector(n_embd, 0.0f); + res->embedding.push_back(std::vector(n_embd, 0.0f)); continue; } - common_embd_normalize(embd, embd_res.data(), n_embd); - res->embedding = embd_res; + // normalize only when there is pooling + // TODO: configurable + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + res->embedding.push_back({ embd, embd + n_embd }); + } } SLT_DBG(slot, "%s", "sending embeddings\n"); @@ -2657,7 +2678,10 @@ struct server_context { // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { - common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false); + // without pooling, we want to output the embeddings for all the tokens in the batch + const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; + + common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); @@ -3665,14 +3689,17 @@ int main(int argc, char ** argv) { res_ok(res, data); }; - const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) { const json body = json::parse(req.body); - bool oaicompat = false; + + if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return; + } // for the shape of input/content, see tokenize_input_prompts() json prompt; - if (body.contains("input")) { - oaicompat = true; + if (body.count("input") != 0) { prompt = body.at("input"); } else if (body.contains("content")) { oaicompat = false; @@ -3697,10 +3724,15 @@ int main(int argc, char ** argv) { { std::vector tasks; for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; task.prompt_tokens = std::move(tokenized_prompts[i]); + + // OAI-compat + task.params.oaicompat = oaicompat; + tasks.push_back(task); } @@ -3728,12 +3760,18 @@ int main(int argc, char ** argv) { } // write JSON response - json root = oaicompat - ? format_embeddings_response_oaicompat(body, responses) - : responses.size() == 1 ? responses[0] : json(responses); + json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses); res_ok(res, root); }; + const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { + handle_embeddings_impl(req, res, false); + }; + + const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { + handle_embeddings_impl(req, res, true); + }; + const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED)); @@ -3907,7 +3945,7 @@ int main(int argc, char ** argv) { svr->Post("/infill", handle_infill); svr->Post("/embedding", handle_embeddings); // legacy svr->Post("/embeddings", handle_embeddings); - svr->Post("/v1/embeddings", handle_embeddings); + svr->Post("/v1/embeddings", handle_embeddings_oai); svr->Post("/rerank", handle_rerank); svr->Post("/reranking", handle_rerank); svr->Post("/v1/rerank", handle_rerank); diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index 4f4e9dcf0..e32d74582 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -14,8 +14,9 @@ def create_server(): def test_embedding_single(): global server + server.pooling = 'last' server.start() - res = server.make_request("POST", "/embeddings", data={ + res = server.make_request("POST", "/v1/embeddings", data={ "input": "I believe the meaning of life is", }) assert res.status_code == 200 @@ -29,8 +30,9 @@ def test_embedding_single(): def test_embedding_multiple(): global server + server.pooling = 'last' server.start() - res = server.make_request("POST", "/embeddings", data={ + res = server.make_request("POST", "/v1/embeddings", data={ "input": [ "I believe the meaning of life is", "Write a joke about AI from a very long prompt which will not be truncated", @@ -46,7 +48,7 @@ def test_embedding_multiple(): @pytest.mark.parametrize( - "content,is_multi_prompt", + "input,is_multi_prompt", [ # single prompt ("string", False), @@ -59,25 +61,55 @@ def test_embedding_multiple(): ([[12, 34, 56], [12, "string", 34, 56]], True), ] ) -def test_embedding_mixed_input(content, is_multi_prompt: bool): +def test_embedding_mixed_input(input, is_multi_prompt: bool): global server server.start() - res = server.make_request("POST", "/embeddings", data={"content": content}) + res = server.make_request("POST", "/v1/embeddings", data={"input": input}) assert res.status_code == 200 + data = res.body['data'] if is_multi_prompt: - assert len(res.body) == len(content) - for d in res.body: + assert len(data) == len(input) + for d in data: assert 'embedding' in d assert len(d['embedding']) > 1 else: - assert 'embedding' in res.body - assert len(res.body['embedding']) > 1 + assert 'embedding' in data[0] + assert len(data[0]['embedding']) > 1 + + +def test_embedding_pooling_none(): + global server + server.pooling = 'none' + server.start() + res = server.make_request("POST", "/embeddings", data={ + "input": "hello hello hello", + }) + assert res.status_code == 200 + assert 'embedding' in res.body[0] + assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special + + # make sure embedding vector is not normalized + for x in res.body[0]['embedding']: + assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON + + +def test_embedding_pooling_none_oai(): + global server + server.pooling = 'none' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": "hello hello hello", + }) + + # /v1/embeddings does not support pooling type 'none' + assert res.status_code == 400 def test_embedding_openai_library_single(): global server + server.pooling = 'last' server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is") assert len(res.data) == 1 assert len(res.data[0].embedding) > 1 @@ -85,8 +117,9 @@ def test_embedding_openai_library_single(): def test_embedding_openai_library_multiple(): global server + server.pooling = 'last' server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") res = client.embeddings.create(model="text-embedding-3-small", input=[ "I believe the meaning of life is", "Write a joke about AI from a very long prompt which will not be truncated", @@ -100,8 +133,9 @@ def test_embedding_openai_library_multiple(): def test_embedding_error_prompt_too_long(): global server + server.pooling = 'last' server.start() - res = server.make_request("POST", "/embeddings", data={ + res = server.make_request("POST", "/v1/embeddings", data={ "input": "This is a test " * 512, }) assert res.status_code != 200 @@ -109,8 +143,9 @@ def test_embedding_error_prompt_too_long(): def test_same_prompt_give_same_result(): + server.pooling = 'last' server.start() - res = server.make_request("POST", "/embeddings", data={ + res = server.make_request("POST", "/v1/embeddings", data={ "input": [ "I believe the meaning of life is", "I believe the meaning of life is", @@ -138,7 +173,7 @@ def test_same_prompt_give_same_result(): def test_embedding_usage_single(content, n_tokens): global server server.start() - res = server.make_request("POST", "/embeddings", data={"input": content}) + res = server.make_request("POST", "/v1/embeddings", data={"input": content}) assert res.status_code == 200 assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] assert res.body['usage']['prompt_tokens'] == n_tokens @@ -147,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens): def test_embedding_usage_multiple(): global server server.start() - res = server.make_request("POST", "/embeddings", data={ + res = server.make_request("POST", "/v1/embeddings", data={ "input": [ "I believe the meaning of life is", "I believe the meaning of life is", diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index d988ccf5e..277125e88 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -65,6 +65,7 @@ class ServerProcess: server_reranking: bool | None = False server_metrics: bool | None = False server_slots: bool | None = False + pooling: str | None = None draft: int | None = None api_key: str | None = None response_format: str | None = None @@ -132,6 +133,8 @@ class ServerProcess: server_args.append("--metrics") if self.server_slots: server_args.append("--slots") + if self.pooling: + server_args.extend(["--pooling", self.pooling]) if self.model_alias: server_args.extend(["--alias", self.model_alias]) if self.n_ctx: