mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 21:37:19 +01:00
server : add support for "encoding_format": "base64" to the */embeddings endpoints (#10967)
* add support for base64 * fix base64 test * improve test --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
parent
2cd43f4900
commit
9ba399dfa7
@ -34,6 +34,7 @@ endforeach()
|
|||||||
add_executable(${TARGET} ${TARGET_SRCS})
|
add_executable(${TARGET} ${TARGET_SRCS})
|
||||||
install(TARGETS ${TARGET} RUNTIME)
|
install(TARGETS ${TARGET} RUNTIME)
|
||||||
|
|
||||||
|
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
|
||||||
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
|
||||||
if (LLAMA_SERVER_SSL)
|
if (LLAMA_SERVER_SSL)
|
||||||
|
@ -3790,6 +3790,17 @@ int main(int argc, char ** argv) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool use_base64 = false;
|
||||||
|
if (body.count("encoding_format") != 0) {
|
||||||
|
const std::string& format = body.at("encoding_format");
|
||||||
|
if (format == "base64") {
|
||||||
|
use_base64 = true;
|
||||||
|
} else if (format != "float") {
|
||||||
|
res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
|
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
|
||||||
for (const auto & tokens : tokenized_prompts) {
|
for (const auto & tokens : tokenized_prompts) {
|
||||||
// this check is necessary for models that do not add BOS token to the input
|
// this check is necessary for models that do not add BOS token to the input
|
||||||
@ -3841,7 +3852,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// write JSON response
|
// write JSON response
|
||||||
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
|
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
|
||||||
res_ok(res, root);
|
res_ok(res, root);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import base64
|
||||||
|
import struct
|
||||||
import pytest
|
import pytest
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from utils import *
|
from utils import *
|
||||||
@ -194,3 +196,42 @@ def test_embedding_usage_multiple():
|
|||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
||||||
assert res.body['usage']['prompt_tokens'] == 2 * 9
|
assert res.body['usage']['prompt_tokens'] == 2 * 9
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_openai_library_base64():
|
||||||
|
server.start()
|
||||||
|
test_input = "Test base64 embedding output"
|
||||||
|
|
||||||
|
# get embedding in default format
|
||||||
|
res = server.make_request("POST", "/v1/embeddings", data={
|
||||||
|
"input": test_input
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
vec0 = res.body["data"][0]["embedding"]
|
||||||
|
|
||||||
|
# get embedding in base64 format
|
||||||
|
res = server.make_request("POST", "/v1/embeddings", data={
|
||||||
|
"input": test_input,
|
||||||
|
"encoding_format": "base64"
|
||||||
|
})
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert "data" in res.body
|
||||||
|
assert len(res.body["data"]) == 1
|
||||||
|
|
||||||
|
embedding_data = res.body["data"][0]
|
||||||
|
assert "embedding" in embedding_data
|
||||||
|
assert isinstance(embedding_data["embedding"], str)
|
||||||
|
|
||||||
|
# Verify embedding is valid base64
|
||||||
|
decoded = base64.b64decode(embedding_data["embedding"])
|
||||||
|
# Verify decoded data can be converted back to float array
|
||||||
|
float_count = len(decoded) // 4 # 4 bytes per float
|
||||||
|
floats = struct.unpack(f'{float_count}f', decoded)
|
||||||
|
assert len(floats) > 0
|
||||||
|
assert all(isinstance(x, float) for x in floats)
|
||||||
|
assert len(floats) == len(vec0)
|
||||||
|
|
||||||
|
# make sure the decoded data is the same as the original
|
||||||
|
for x, y in zip(floats, vec0):
|
||||||
|
assert abs(x - y) < EPSILON
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
#include "common/base64.hpp"
|
||||||
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
// crash the server in debug mode, otherwise send an http 500 error
|
// crash the server in debug mode, otherwise send an http 500 error
|
||||||
@ -613,16 +614,31 @@ static json oaicompat_completion_params_parse(
|
|||||||
return llama_params;
|
return llama_params;
|
||||||
}
|
}
|
||||||
|
|
||||||
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
|
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
|
||||||
json data = json::array();
|
json data = json::array();
|
||||||
int32_t n_tokens = 0;
|
int32_t n_tokens = 0;
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (const auto & elem : embeddings) {
|
for (const auto & elem : embeddings) {
|
||||||
data.push_back(json{
|
json embedding_obj;
|
||||||
|
|
||||||
|
if (use_base64) {
|
||||||
|
const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
|
||||||
|
const char* data_ptr = reinterpret_cast<const char*>(vec.data());
|
||||||
|
size_t data_size = vec.size() * sizeof(float);
|
||||||
|
embedding_obj = {
|
||||||
|
{"embedding", base64::encode(data_ptr, data_size)},
|
||||||
|
{"index", i++},
|
||||||
|
{"object", "embedding"},
|
||||||
|
{"encoding_format", "base64"}
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
embedding_obj = {
|
||||||
{"embedding", json_value(elem, "embedding", json::array())},
|
{"embedding", json_value(elem, "embedding", json::array())},
|
||||||
{"index", i++},
|
{"index", i++},
|
||||||
{"object", "embedding"}
|
{"object", "embedding"}
|
||||||
});
|
};
|
||||||
|
}
|
||||||
|
data.push_back(embedding_obj);
|
||||||
|
|
||||||
n_tokens += json_value(elem, "tokens_evaluated", 0);
|
n_tokens += json_value(elem, "tokens_evaluated", 0);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user