server : (embeddings) using same format for "input" and "content" (#10872)

* server : (embeddings) using same format for "input" and "content"

* fix test case

* handle empty input case

* fix test
Author: Xuan Son Nguyen, 2024-12-18 09:55:09 +01:00 (committed by GitHub)
Parent: 6b064c92b4
Commit: 46828872c3
3 changed files with 47 additions and 9 deletions
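
In practice, this change means the /embeddings handler accepts the same prompt shapes through "content" as through the OAI-style "input" field: a single string, a list of token IDs, a mixed string/token prompt, or a list of any of these. Below is a minimal client-side sketch, not part of the diff, assuming a local llama-server with an embedding model reachable at http://localhost:8080 and using the Python requests package:

import requests

url = "http://localhost:8080/embeddings"  # assumed local server, adjust as needed

# a mixed batch: a plain string, a token-ID list, and a mixed string/token prompt
content = ["string1", [12, 34, 56], [12, 34, "string", 56, 78]]

res = requests.post(url, json={"content": content})
res.raise_for_status()

# one result per prompt, each carrying an "embedding" vector
assert len(res.json()) == len(content)
for item in res.json():
    print(len(item["embedding"]))

Sending the same payload under "input" instead of "content" goes through the OpenAI-compatible response path, as the handler change below shows.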

@@ -3651,25 +3651,33 @@ int main(int argc, char ** argv) {
         const json body = json::parse(req.body);
         bool oaicompat = false;
 
-        // an input prompt can be a string or a list of tokens (integer)
+        // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
-        if (body.count("input") != 0) {
+        if (body.contains("input")) {
             oaicompat = true;
             prompt = body.at("input");
-        } else if (body.count("content") != 0) {
-            // with "content", we only support single prompt
-            prompt = std::vector<std::string>{body.at("content")};
+        } else if (body.contains("content")) {
+            oaicompat = false;
+            prompt = body.at("content");
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        for (const auto & tokens : tokenized_prompts) {
+            // this check is necessary for models that do not add BOS token to the input
+            if (tokens.empty()) {
+                res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+
         // create and queue the task
         json responses = json::array();
         bool error = false;
         {
             std::vector<server_task> tasks;
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
                 task.id = ctx_server.queue_tasks.get_new_id();
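
The new guard rejects any prompt that tokenizes to an empty token list, which can happen for an empty string when the model's tokenizer adds no BOS token. A hedged sketch of the expected client-side effect, reusing the assumed local server from the example above:

import requests

# an empty prompt may produce zero tokens for models without an automatic BOS token;
# the guard above is expected to reject it before a task is queued
res = requests.post("http://localhost:8080/embeddings", json={"content": [""]})

print(res.status_code)  # expected: 400 (invalid request) when the prompt tokenizes to nothing
print(res.text)         # error payload explaining that the input content cannot be empty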

@@ -45,6 +45,35 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+@pytest.mark.parametrize(
+    "content,is_multi_prompt",
+    [
+        # single prompt
+        ("string", False),
+        ([12, 34, 56], False),
+        ([12, 34, "string", 56, 78], False),
+        # multiple prompts
+        (["string1", "string2"], True),
+        (["string1", [12, 34, 56]], True),
+        ([[12, 34, 56], [12, 34, 56]], True),
+        ([[12, 34, 56], [12, "string", 34, 56]], True),
+    ]
+)
+def test_embedding_mixed_input(content, is_multi_prompt: bool):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"content": content})
+    assert res.status_code == 200
+    if is_multi_prompt:
+        assert len(res.body) == len(content)
+        for d in res.body:
+            assert 'embedding' in d
+            assert len(d['embedding']) > 1
+    else:
+        assert 'embedding' in res.body
+        assert len(res.body['embedding']) > 1
+
+
 def test_embedding_openai_library_single():
     global server
     server.start()
@@ -102,8 +131,8 @@ def test_same_prompt_give_same_result():
 @pytest.mark.parametrize(
     "content,n_tokens",
     [
-        ("I believe the meaning of life is", 7),
-        ("This is a test", 4),
+        ("I believe the meaning of life is", 9),
+        ("This is a test", 6),
     ]
 )
 def test_embedding_usage_single(content, n_tokens):

@@ -126,4 +155,4 @@ def test_embedding_usage_multiple():
     })
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
-    assert res.body['usage']['prompt_tokens'] == 2 * 7
+    assert res.body['usage']['prompt_tokens'] == 2 * 9
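
The expected token counts grow by two because the handler now tokenizes with add_special enabled (the third argument of tokenize_input_prompts changed from false to true in the first hunk), so the model's special tokens are counted in usage as well. A sketch of how to observe the difference via the server's /tokenize endpoint, under the same assumed local server; the exact counts depend on the model the test suite runs with:

import requests

base = "http://localhost:8080"  # assumed local server
text = "I believe the meaning of life is"

plain   = requests.post(f"{base}/tokenize", json={"content": text, "add_special": False}).json()
special = requests.post(f"{base}/tokenize", json={"content": text, "add_special": True}).json()

print(len(plain["tokens"]))    # e.g. 7 with the test model
print(len(special["tokens"]))  # e.g. 9 once the special tokens are included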

@@ -138,6 +138,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
  * and multiple prompts (multi-tasks):
  * - "prompt": ["string1", "string2"]
  * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
  * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
  */
 static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {