Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2024-12-26 06:10:29 +01:00
server : (embeddings) using same format for "input" and "content" (#10872)
* server : (embeddings) using same format for "input" and "content"
* fix test case
* handle empty input case
* fix test
This commit is contained in:
parent 6b064c92b4
commit 46828872c3
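As a quick illustration of what this change means for clients (not part of the commit): after it, the non-OpenAI-compatible "content" field of the /embeddings endpoint accepts the same prompt shapes as the OpenAI-style "input" field, i.e. a string, a token list, a mixed string/token list, or a list of such prompts. The sketch below assumes a llama.cpp server with embeddings enabled at http://localhost:8080 and uses the third-party requests library; the URL and payload are illustrative.

# Illustrative only: assumes a llama.cpp server with embeddings enabled
# at http://localhost:8080 (adjust as needed); uses the third-party "requests" lib.
import requests

BASE_URL = "http://localhost:8080"   # assumed local server address
payload = ["string1", [12, 34, 56]]  # two prompts: a string and a token list

# OpenAI-compatible field
r_input = requests.post(f"{BASE_URL}/embeddings", json={"input": payload})

# "content" now accepts the same shapes instead of a single prompt only
r_content = requests.post(f"{BASE_URL}/embeddings", json={"content": payload})

print(r_input.status_code, r_content.status_code)  # expected: 200 200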
@@ -3651,25 +3651,33 @@ int main(int argc, char ** argv) {
         const json body = json::parse(req.body);
         bool oaicompat = false;
 
-        // an input prompt can be a string or a list of tokens (integer)
+        // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
-        if (body.count("input") != 0) {
+        if (body.contains("input")) {
             oaicompat = true;
             prompt = body.at("input");
-        } else if (body.count("content") != 0) {
-            // with "content", we only support single prompt
-            prompt = std::vector<std::string>{body.at("content")};
+        } else if (body.contains("content")) {
+            oaicompat = false;
+            prompt = body.at("content");
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        for (const auto & tokens : tokenized_prompts) {
+            // this check is necessary for models that do not add BOS token to the input
+            if (tokens.empty()) {
+                res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+
         // create and queue the task
         json responses = json::array();
         bool error = false;
         {
             std::vector<server_task> tasks;
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
                 task.id = ctx_server.queue_tasks.get_new_id();
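A hedged sketch of the behaviour added by the validation loop above: for a model whose tokenizer does not insert a BOS token, an empty prompt now produces an invalid-request error instead of an empty task reaching the queue. The snippet assumes a local server at http://localhost:8080 and that ERROR_TYPE_INVALID_REQUEST surfaces as HTTP 400; the exact error payload shape is an assumption.

# Assumption-laden sketch: local server at http://localhost:8080, a model that
# does not add a BOS token, and ERROR_TYPE_INVALID_REQUEST reported as HTTP 400.
import requests

res = requests.post("http://localhost:8080/embeddings", json={"content": ""})
# "" tokenizes to an empty token list for such a model, so the request
# should now be rejected up front rather than queued as an empty task.
assert res.status_code == 400
print(res.json())  # expected to mention that input content cannot be empty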
@@ -45,6 +45,35 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+@pytest.mark.parametrize(
+    "content,is_multi_prompt",
+    [
+        # single prompt
+        ("string", False),
+        ([12, 34, 56], False),
+        ([12, 34, "string", 56, 78], False),
+        # multiple prompts
+        (["string1", "string2"], True),
+        (["string1", [12, 34, 56]], True),
+        ([[12, 34, 56], [12, 34, 56]], True),
+        ([[12, 34, 56], [12, "string", 34, 56]], True),
+    ]
+)
+def test_embedding_mixed_input(content, is_multi_prompt: bool):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"content": content})
+    assert res.status_code == 200
+    if is_multi_prompt:
+        assert len(res.body) == len(content)
+        for d in res.body:
+            assert 'embedding' in d
+            assert len(d['embedding']) > 1
+    else:
+        assert 'embedding' in res.body
+        assert len(res.body['embedding']) > 1
+
+
 def test_embedding_openai_library_single():
     global server
     server.start()
@@ -102,8 +131,8 @@ def test_same_prompt_give_same_result():
 @pytest.mark.parametrize(
     "content,n_tokens",
     [
-        ("I believe the meaning of life is", 7),
-        ("This is a test", 4),
+        ("I believe the meaning of life is", 9),
+        ("This is a test", 6),
     ]
 )
 def test_embedding_usage_single(content, n_tokens):
@@ -126,4 +155,4 @@ def test_embedding_usage_multiple():
     })
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
-    assert res.body['usage']['prompt_tokens'] == 2 * 7
+    assert res.body['usage']['prompt_tokens'] == 2 * 9
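The expected token counts in the two usage tests above each grow by two. That is consistent with the first hunk: the handler now calls tokenize_input_prompts with add_special enabled (it previously passed /* add_special */ false), so the test model's special tokens are counted toward usage. A hypothetical way to sanity-check such counts, assuming the server exposes a /tokenize endpoint with an "add_special" flag and noting that the exact numbers depend on the model loaded by the test suite:

# Hypothetical sanity check; assumes http://localhost:8080, a /tokenize endpoint
# with an "add_special" flag, and the same model the test suite loads
# (the expected counts 9 and 6 are model-dependent).
import requests

BASE_URL = "http://localhost:8080"

for text, expected in [("I believe the meaning of life is", 9), ("This is a test", 6)]:
    res = requests.post(f"{BASE_URL}/tokenize",
                        json={"content": text, "add_special": True})
    tokens = res.json()["tokens"]
    print(text, len(tokens), "expected:", expected)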
@@ -138,6 +138,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
  * and multiple prompts (multi-tasks):
  * - "prompt": ["string1", "string2"]
  * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
  * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
  */
 static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
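To make the documented prompt shapes concrete, here is a rough Python mirror of how a helper like tokenize_input_prompts dispatches on them: a string, a pure token list, or a mixed string/token list counts as a single prompt, while any other array is treated as one prompt per element. This is an illustration, not the server's C++ implementation; tokenize_mixed below is a stand-in tokenizer.

# Conceptual mirror of the documented prompt shapes (not the C++ implementation).
def tokenize_mixed(prompt):
    # stand-in tokenizer: ints pass through as token ids, strings map to fake ids
    parts = [prompt] if isinstance(prompt, str) else prompt
    out = []
    for e in parts:
        out.extend([e] if isinstance(e, int) else [abs(hash(e)) % 30000])
    return out

def is_array_of_numbers(p):
    return isinstance(p, list) and len(p) > 0 and all(isinstance(e, int) for e in p)

def is_mixed_numbers_strings(p):
    return (isinstance(p, list)
            and any(isinstance(e, int) for e in p)
            and any(isinstance(e, str) for e in p))

def normalize_prompts(json_prompt):
    # returns one token list per prompt/task
    if isinstance(json_prompt, str) or is_mixed_numbers_strings(json_prompt):
        return [tokenize_mixed(json_prompt)]                    # single prompt
    if is_array_of_numbers(json_prompt):
        return [list(json_prompt)]                              # single pre-tokenized prompt
    return [normalize_prompts(p)[0] for p in json_prompt]       # one prompt per element

# shapes from the comment above: 1 = single task, 2 = multi-task
assert len(normalize_prompts("string")) == 1
assert len(normalize_prompts([12, 34, "string", 56, 78])) == 1
assert len(normalize_prompts(["string1", "string2"])) == 2
assert len(normalize_prompts([[12, 34, 56], [78, 90, 12]])) == 2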