Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-01-11 21:10:24 +01:00
server : normalize embeddings (#5956)
* output normalize embedding in '/v1/embeddings'
* common : reuse llama_embd_normalize
* common : better normalize impl

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
parent 2c4f566c88
commit fb215c3832
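The change is easier to read with the math in mind: the new helper L2-normalizes an embedding, i.e. scales it to unit Euclidean length, after which cosine similarity between two embeddings is just their dot product. The sketch below is a standalone illustration of that idea, not code from this commit; the l2_normalize helper and the toy vectors are made up for the example.

// Standalone illustration (not part of the patch): L2-normalize two toy
// "embeddings", then compare them with a plain dot product.
#include <cmath>
#include <cstdio>
#include <vector>

// Same idea as the llama_embd_normalize helper introduced below:
// scale by 1/||v||, guarding against an all-zero vector.
static void l2_normalize(const float * inp, float * out, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += inp[i] * inp[i];
    }
    const double len   = std::sqrt(sum);
    const float  scale = len > 0.0 ? (float) (1.0 / len) : 0.0f;
    for (int i = 0; i < n; i++) {
        out[i] = inp[i] * scale;
    }
}

int main() {
    const std::vector<float> a = {3.0f, 4.0f, 0.0f};
    const std::vector<float> b = {1.0f, 2.0f, 2.0f};

    std::vector<float> na(a.size()), nb(b.size());
    l2_normalize(a.data(), na.data(), (int) a.size());
    l2_normalize(b.data(), nb.data(), (int) b.size());

    // On unit-length vectors, cosine similarity is just the dot product.
    double cos_sim = 0.0;
    for (int i = 0; i < (int) na.size(); i++) {
        cos_sim += na[i] * nb[i];
    }
    printf("cosine similarity: %f\n", cos_sim);  // ~0.733 for these toy vectors
    return 0;
}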
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1852,3 +1852,18 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
     printf("\n=== Done dumping\n");
 }
 
+void llama_embd_normalize(const float * inp, float * out, int n) {
+    double sum = 0.0;
+    for (int i = 0; i < n; i++) {
+        sum += inp[i] * inp[i];
+    }
+    sum = sqrt(sum);
+
+    const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
+
+    for (int i = 0; i < n; i++) {
+        out[i] = inp[i] * norm;
+    }
+}
+
--- a/common/common.h
+++ b/common/common.h
@@ -260,3 +260,10 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
 // Dump the KV cache view showing individual sequences in each cell (long output).
 void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
 
+//
+// Embedding utils
+//
+
+void llama_embd_normalize(const float * inp, float * out, int n);
+
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -23,17 +23,6 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void normalize(const float * vec, float * out, int n) {
-    float norm = 0;
-    for (int i = 0; i < n; i++) {
-        norm += vec[i] * vec[i];
-    }
-    norm = sqrt(norm);
-    for (int i = 0; i < n; i++) {
-        out[i] = vec[i] / norm;
-    }
-}
-
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
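The local normalize() removed above is replaced by the shared llama_embd_normalize helper (the call-site change is in a later hunk). Beyond de-duplication, the behaviour differs slightly: the old function accumulates in float and divides by the norm unconditionally, so an all-zero embedding produces NaNs, while the new helper accumulates in double and falls back to a zero scale. The snippet below is a standalone illustration of that edge case, with made-up inputs rather than real embeddings.

// Standalone illustration (not llama.cpp code) of the zero-norm guard:
// dividing an all-zero vector by its (zero) norm yields NaN,
// scaling it by 0 keeps it at zero.
#include <cmath>
#include <cstdio>

int main() {
    const int n = 4;
    const float inp[] = {0.0f, 0.0f, 0.0f, 0.0f};
    float out_old[4];
    float out_new[4];

    // old behaviour: unconditional division by the norm
    float norm = 0.0f;
    for (int i = 0; i < n; i++) norm += inp[i] * inp[i];
    norm = std::sqrt(norm);                                     // 0.0f here
    for (int i = 0; i < n; i++) out_old[i] = inp[i] / norm;     // 0/0 -> NaN

    // new behaviour: precomputed scale with a zero-norm guard
    double sum = 0.0;
    for (int i = 0; i < n; i++) sum += inp[i] * inp[i];
    const double len   = std::sqrt(sum);
    const float  scale = len > 0.0 ? (float) (1.0 / len) : 0.0f;
    for (int i = 0; i < n; i++) out_new[i] = inp[i] * scale;    // stays 0.0f

    printf("old: %f  new: %f\n", out_old[0], out_new[0]);       // nan vs 0.000000
    return 0;
}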
@@ -44,7 +33,6 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
         fprintf(stderr, "%s : failed to decode\n", __func__);
     }
 
-    // normalize on copy
     for (int i = 0; i < batch.n_tokens; i++) {
         if (!batch.logits[i]) {
             continue;
@@ -61,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
         }
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        normalize(embd, out, n_embd);
+        llama_embd_normalize(embd, out, n_embd);
     }
 }
 
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1327,6 +1327,8 @@ struct server_context {
 
         const int n_embd = llama_n_embd(model);
 
+        std::vector<float> embd_res(n_embd, 0.0f);
+
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
                 continue;
@@ -1350,8 +1352,10 @@ struct server_context {
                 continue;
             }
 
+            llama_embd_normalize(embd, embd_res.data(), n_embd);
+
             res.data = json {
-                {"embedding", std::vector<float>(embd, embd + n_embd)},
+                {"embedding", embd_res},
             };
         }
 
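With the change above, the vector stored in embd_res (and therefore the "embedding" array the server sends back) should have unit L2 norm. A rough client-side sanity check is sketched below; it assumes the nlohmann::json single header that the server itself uses, and the hard-coded body is a simplified stand-in for a real response, not the exact '/v1/embeddings' schema.

// Hypothetical client-side check (simplified response shape, not the exact
// OpenAI-style schema): parse the embedding array and verify its L2 norm.
#include <cmath>
#include <cstdio>
#include <vector>
#include <nlohmann/json.hpp>

int main() {
    // stand-in for a response body received over HTTP
    const char * body = R"({"embedding": [0.6, 0.8, 0.0]})";

    const nlohmann::json j = nlohmann::json::parse(body);
    const std::vector<float> embd = j.at("embedding").get<std::vector<float>>();

    double sum = 0.0;
    for (const float v : embd) {
        sum += (double) v * v;
    }
    printf("L2 norm = %f (expected ~1.0)\n", std::sqrt(sum));
    return 0;
}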
@@ -3354,6 +3358,8 @@ int main(int argc, char ** argv) {
                 // get the result
                 server_task_result result = ctx_server.queue_results.recv(id_task);
                 ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+                // append to the responses
                 responses.push_back(result.data);
             }
 