mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 15:18:26 +01:00
llama : fix embeddings
ggml-ci
This commit is contained in:
parent
a0fc62661f
commit
d0347840c1
@ -1299,7 +1299,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
|||||||
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
||||||
cparams.seed = params.seed;
|
cparams.seed = params.seed;
|
||||||
cparams.logits_all = params.logits_all;
|
cparams.logits_all = params.logits_all;
|
||||||
cparams.embedding = params.embedding;
|
cparams.embeddings = params.embedding;
|
||||||
cparams.rope_scaling_type = params.rope_scaling_type;
|
cparams.rope_scaling_type = params.rope_scaling_type;
|
||||||
cparams.rope_freq_base = params.rope_freq_base;
|
cparams.rope_freq_base = params.rope_freq_base;
|
||||||
cparams.rope_freq_scale = params.rope_freq_scale;
|
cparams.rope_freq_scale = params.rope_freq_scale;
|
||||||
|
@ -19,7 +19,7 @@ static std::vector<std::string> split_lines(const std::string & s) {
|
|||||||
|
|
||||||
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
|
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
|
||||||
for (size_t i = 0; i < tokens.size(); i++) {
|
for (size_t i = 0; i < tokens.size(); i++) {
|
||||||
llama_batch_add(batch, tokens[i], i, { seq_id }, false);
|
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,9 +45,13 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
// normalize on copy
|
// normalize on copy
|
||||||
for (int k = 0; k < n_seq; k++) {
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
float * emb = llama_get_embeddings_ith(ctx, k);
|
if (!batch.logits[i]) {
|
||||||
float * out = output + k * n_embd;
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
float * emb = llama_get_embeddings_ith(ctx, i);
|
||||||
|
float * out = output + batch.seq_id[i][0] * n_embd;
|
||||||
normalize(emb, out, n_embd);
|
normalize(emb, out, n_embd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -145,6 +149,7 @@ int main(int argc, char ** argv) {
|
|||||||
for (int k = 0; k < n_prompts; k++) {
|
for (int k = 0; k < n_prompts; k++) {
|
||||||
// clamp to n_batch tokens
|
// clamp to n_batch tokens
|
||||||
auto & inp = inputs[k];
|
auto & inp = inputs[k];
|
||||||
|
|
||||||
const uint64_t n_toks = inp.size();
|
const uint64_t n_toks = inp.size();
|
||||||
|
|
||||||
// encode if at capacity
|
// encode if at capacity
|
||||||
|
34
examples/server-embd.py
Normal file
34
examples/server-embd.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
n = 8
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
async def requests_post_async(*args, **kwargs):
|
||||||
|
return await asyncio.to_thread(requests.post, *args, **kwargs)
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
model_url = "http://127.0.0.1:6900"
|
||||||
|
responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
|
||||||
|
url= f"{model_url}/embedding",
|
||||||
|
json= {"content": str(i)*32}
|
||||||
|
) for i in range(n)])
|
||||||
|
|
||||||
|
for response in responses:
|
||||||
|
embedding = response.json()["embedding"]
|
||||||
|
print(embedding[-8:])
|
||||||
|
result.append(embedding)
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
|
# compute cosine similarity
|
||||||
|
|
||||||
|
for i in range(n-1):
|
||||||
|
for j in range(i+1, n):
|
||||||
|
embedding1 = np.array(result[i])
|
||||||
|
embedding2 = np.array(result[j])
|
||||||
|
similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
|
||||||
|
print(f"Similarity between {i} and {j}: {similarity:.2f}")
|
||||||
|
|
@ -1210,7 +1210,7 @@ struct llama_server_context
|
|||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_embedding(server_slot &slot)
|
void send_embedding(server_slot & slot, const llama_batch & batch)
|
||||||
{
|
{
|
||||||
task_result res;
|
task_result res;
|
||||||
res.id = slot.task_id;
|
res.id = slot.task_id;
|
||||||
@ -1219,6 +1219,7 @@ struct llama_server_context
|
|||||||
res.stop = true;
|
res.stop = true;
|
||||||
|
|
||||||
const int n_embd = llama_n_embd(model);
|
const int n_embd = llama_n_embd(model);
|
||||||
|
|
||||||
if (!params.embedding)
|
if (!params.embedding)
|
||||||
{
|
{
|
||||||
LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
|
LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
|
||||||
@ -1229,13 +1230,20 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
const float *data = llama_get_embeddings(ctx);
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
|
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float * data = llama_get_embeddings_ith(ctx, i);
|
||||||
std::vector<float> embedding(data, data + n_embd);
|
std::vector<float> embedding(data, data + n_embd);
|
||||||
|
|
||||||
res.result_json = json
|
res.result_json = json
|
||||||
{
|
{
|
||||||
{"embedding", embedding},
|
{"embedding", embedding },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
}
|
||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1845,7 +1853,7 @@ struct llama_server_context
|
|||||||
ga_i += ga_w/ga_n;
|
ga_i += ga_w/ga_n;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
|
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
|
||||||
slot_npast++;
|
slot_npast++;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1881,7 +1889,7 @@ struct llama_server_context
|
|||||||
|
|
||||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
|
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
|
||||||
{
|
{
|
||||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||||
|
|
||||||
for (auto & slot : slots)
|
for (auto & slot : slots)
|
||||||
{
|
{
|
||||||
@ -1954,7 +1962,7 @@ struct llama_server_context
|
|||||||
// prompt evaluated for embedding
|
// prompt evaluated for embedding
|
||||||
if (slot.embedding)
|
if (slot.embedding)
|
||||||
{
|
{
|
||||||
send_embedding(slot);
|
send_embedding(slot, batch_view);
|
||||||
slot.release();
|
slot.release();
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
continue;
|
continue;
|
||||||
@ -2330,7 +2338,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.n_batch = std::stoi(argv[i]);
|
params.n_batch = std::stoi(argv[i]);
|
||||||
params.n_batch = std::min(512, params.n_batch);
|
|
||||||
}
|
}
|
||||||
else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
|
else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
|
||||||
{
|
{
|
||||||
|
95
llama.cpp
95
llama.cpp
@ -1682,7 +1682,9 @@ struct llama_cparams {
|
|||||||
float yarn_beta_slow;
|
float yarn_beta_slow;
|
||||||
float defrag_thold;
|
float defrag_thold;
|
||||||
|
|
||||||
|
bool embeddings;
|
||||||
bool offload_kqv;
|
bool offload_kqv;
|
||||||
|
|
||||||
enum llama_pooling_type pooling_type;
|
enum llama_pooling_type pooling_type;
|
||||||
|
|
||||||
ggml_backend_sched_eval_callback cb_eval;
|
ggml_backend_sched_eval_callback cb_eval;
|
||||||
@ -1972,7 +1974,7 @@ struct llama_context {
|
|||||||
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
||||||
int32_t n_eval = 0; // number of eval calls
|
int32_t n_eval = 0; // number of eval calls
|
||||||
|
|
||||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
// logits output (2-dimensional array: [n_tokens][n_vocab])
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
// guard against access to unset logits
|
// guard against access to unset logits
|
||||||
@ -1980,8 +1982,8 @@ struct llama_context {
|
|||||||
#endif
|
#endif
|
||||||
bool logits_all = false;
|
bool logits_all = false;
|
||||||
|
|
||||||
// input embedding (1-dimensional array: [n_embd])
|
// embeddings output (2-dimensional array: [n_tokens][n_embd])
|
||||||
std::vector<float> embedding;
|
std::vector<float> embeddings;
|
||||||
|
|
||||||
// memory buffers used to evaluate the model
|
// memory buffers used to evaluate the model
|
||||||
std::vector<uint8_t> buf_compute_meta;
|
std::vector<uint8_t> buf_compute_meta;
|
||||||
@ -5092,6 +5094,7 @@ static struct ggml_tensor * llm_build_kv(
|
|||||||
llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
|
llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
|
cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
|
||||||
q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
|
q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
@ -6085,6 +6088,7 @@ struct llm_build_context {
|
|||||||
|
|
||||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||||
|
|
||||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||||
|
|
||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
@ -6092,6 +6096,7 @@ struct llm_build_context {
|
|||||||
|
|
||||||
// get input vectors with right size
|
// get input vectors with right size
|
||||||
const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
|
const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
|
||||||
|
|
||||||
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
|
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
|
||||||
struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0);
|
struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0);
|
||||||
struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0);
|
struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0);
|
||||||
@ -8196,16 +8201,16 @@ static int llama_decode_internal(
|
|||||||
|
|
||||||
// the output is always the last tensor in the graph
|
// the output is always the last tensor in the graph
|
||||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||||
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
|
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
|
||||||
|
|
||||||
if (strcmp(res->name, "result_output") == 0) {
|
if (strcmp(res->name, "result_output") == 0) {
|
||||||
// the embeddings could be the second to last tensor, or the third to last tensor
|
// the embeddings could be the second to last tensor, or the third to last tensor
|
||||||
if (strcmp(embeddings->name, "result_norm") != 0) {
|
if (strcmp(embd->name, "result_norm") != 0) {
|
||||||
embeddings = gf->nodes[gf->n_nodes - 3];
|
embd = gf->nodes[gf->n_nodes - 3];
|
||||||
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
|
||||||
}
|
}
|
||||||
} else if (strcmp(res->name, "result_embd") == 0) {
|
} else if (strcmp(res->name, "result_embd") == 0) {
|
||||||
embeddings = res;
|
embd = res;
|
||||||
res = nullptr;
|
res = nullptr;
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
@ -8275,46 +8280,57 @@ static int llama_decode_internal(
|
|||||||
logits_out.clear();
|
logits_out.clear();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res);
|
ggml_backend_t backend_res = ggml_backend_sched_get_node_backend(lctx.sched, res);
|
||||||
GGML_ASSERT(res_backend != nullptr);
|
GGML_ASSERT(backend_res != nullptr);
|
||||||
|
|
||||||
if (batch.logits) {
|
if (batch.logits) {
|
||||||
logits_out.resize(n_vocab * n_tokens);
|
logits_out.resize(n_vocab * n_tokens);
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
if (batch.logits[i] == 0) {
|
if (batch.logits[i] == 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
|
ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
logits_valid[i] = true;
|
logits_valid[i] = true;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
} else if (lctx.logits_all) {
|
} else if (lctx.logits_all) {
|
||||||
logits_out.resize(n_vocab * n_tokens);
|
logits_out.resize(n_vocab * n_tokens);
|
||||||
ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
|
ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
std::fill(logits_valid.begin(), logits_valid.end(), true);
|
std::fill(logits_valid.begin(), logits_valid.end(), true);
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
logits_out.resize(n_vocab);
|
logits_out.resize(n_vocab);
|
||||||
ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
|
ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
logits_valid[0] = true;
|
logits_valid[0] = true;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
ggml_backend_synchronize(res_backend);
|
ggml_backend_synchronize(backend_res);
|
||||||
}
|
}
|
||||||
|
|
||||||
// extract embeddings
|
// extract embeddings
|
||||||
if (!lctx.embedding.empty()) {
|
if (cparams.embeddings && embd) {
|
||||||
auto & embedding_out = lctx.embedding;
|
auto & embeddings_out = lctx.embeddings;
|
||||||
|
|
||||||
const int64_t embd_pos = res ? n_embd * (n_tokens-1) : 0;
|
ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
|
||||||
const int64_t embd_size = res ? n_embd : n_embd * n_tokens;
|
GGML_ASSERT(backend_embd != nullptr);
|
||||||
|
|
||||||
embedding_out.resize(embd_size);
|
if (batch.logits) {
|
||||||
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
|
embeddings_out.resize(n_embd * n_tokens);
|
||||||
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float));
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
ggml_backend_synchronize(embeddings_backend);
|
if (batch.logits[i] == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
||||||
|
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
|
||||||
|
} else {
|
||||||
|
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ggml_backend_synchronize(backend_embd);
|
||||||
}
|
}
|
||||||
|
|
||||||
// measure the performance only for the single-token evals
|
// measure the performance only for the single-token evals
|
||||||
@ -11864,7 +11880,7 @@ struct llama_context_params llama_context_default_params() {
|
|||||||
/*.type_k =*/ GGML_TYPE_F16,
|
/*.type_k =*/ GGML_TYPE_F16,
|
||||||
/*.type_v =*/ GGML_TYPE_F16,
|
/*.type_v =*/ GGML_TYPE_F16,
|
||||||
/*.logits_all =*/ false,
|
/*.logits_all =*/ false,
|
||||||
/*.embedding =*/ false,
|
/*.embeddings =*/ false,
|
||||||
/*.offload_kqv =*/ true,
|
/*.offload_kqv =*/ true,
|
||||||
/*.abort_callback =*/ nullptr,
|
/*.abort_callback =*/ nullptr,
|
||||||
/*.abort_callback_data =*/ nullptr,
|
/*.abort_callback_data =*/ nullptr,
|
||||||
@ -12015,6 +12031,7 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||||
cparams.defrag_thold = params.defrag_thold;
|
cparams.defrag_thold = params.defrag_thold;
|
||||||
|
cparams.embeddings = params.embeddings;
|
||||||
cparams.offload_kqv = params.offload_kqv;
|
cparams.offload_kqv = params.offload_kqv;
|
||||||
cparams.pooling_type = params.pooling_type;
|
cparams.pooling_type = params.pooling_type;
|
||||||
|
|
||||||
@ -12192,8 +12209,8 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
// resized during inference, reserve maximum
|
// resized during inference, reserve maximum
|
||||||
ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
|
ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
|
||||||
|
|
||||||
if (params.embedding) {
|
if (params.embeddings) {
|
||||||
ctx->embedding.resize(hparams.n_embd);
|
ctx->embeddings.reserve(hparams.n_embd*cparams.n_batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
// graph inputs
|
// graph inputs
|
||||||
@ -12628,7 +12645,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
|
|||||||
// assume worst case for logits although only currently set ones are serialized
|
// assume worst case for logits although only currently set ones are serialized
|
||||||
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
|
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
|
||||||
const size_t s_embedding_size = sizeof(size_t);
|
const size_t s_embedding_size = sizeof(size_t);
|
||||||
const size_t s_embedding = ctx->embedding.size() * sizeof(float);
|
const size_t s_embedding = ctx->embeddings.capacity() * sizeof(float);
|
||||||
const size_t s_kv_buf_size = sizeof(size_t);
|
const size_t s_kv_buf_size = sizeof(size_t);
|
||||||
const size_t s_kv_head = sizeof(uint32_t);
|
const size_t s_kv_head = sizeof(uint32_t);
|
||||||
const size_t s_kv_size = sizeof(uint32_t);
|
const size_t s_kv_size = sizeof(uint32_t);
|
||||||
@ -12737,12 +12754,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
|||||||
|
|
||||||
// copy embeddings
|
// copy embeddings
|
||||||
{
|
{
|
||||||
const size_t embedding_size = ctx->embedding.size();
|
const size_t embeddings_size = ctx->embeddings.size();
|
||||||
|
|
||||||
data_ctx->write(&embedding_size, sizeof(embedding_size));
|
data_ctx->write(&embeddings_size, sizeof(embeddings_size));
|
||||||
|
|
||||||
if (embedding_size) {
|
if (embeddings_size) {
|
||||||
data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float));
|
data_ctx->write(ctx->embeddings.data(), embeddings_size * sizeof(float));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -12846,15 +12863,17 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
|||||||
|
|
||||||
// set embeddings
|
// set embeddings
|
||||||
{
|
{
|
||||||
size_t embedding_size;
|
size_t embeddings_size;
|
||||||
|
|
||||||
memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size);
|
memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
|
||||||
|
|
||||||
GGML_ASSERT(ctx->embedding.capacity() == embedding_size);
|
GGML_ASSERT(ctx->embeddings.capacity() == embeddings_size);
|
||||||
|
|
||||||
if (embedding_size) {
|
if (embeddings_size) {
|
||||||
memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float));
|
ctx->embeddings.resize(embeddings_size);
|
||||||
inp += embedding_size * sizeof(float);
|
|
||||||
|
memcpy(ctx->embeddings.data(), inp, embeddings_size * sizeof(float));
|
||||||
|
inp += embeddings_size * sizeof(float);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -13104,11 +13123,11 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings(struct llama_context * ctx) {
|
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||||
return ctx->embedding.data();
|
return ctx->embeddings.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
||||||
return ctx->embedding.data() + i*ctx->model.hparams.n_embd;
|
return ctx->embeddings.data() + i*ctx->model.hparams.n_embd;
|
||||||
}
|
}
|
||||||
|
|
||||||
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
|
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
|
||||||
|
8
llama.h
8
llama.h
@ -163,7 +163,7 @@ extern "C" {
|
|||||||
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
||||||
// - pos : the positions of the respective token in the sequence
|
// - pos : the positions of the respective token in the sequence
|
||||||
// - seq_id : the sequence to which the respective token belongs
|
// - seq_id : the sequence to which the respective token belongs
|
||||||
// - logits : if zero, the logits for the respective token will not be output
|
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
||||||
//
|
//
|
||||||
typedef struct llama_batch {
|
typedef struct llama_batch {
|
||||||
int32_t n_tokens;
|
int32_t n_tokens;
|
||||||
@ -173,7 +173,7 @@ extern "C" {
|
|||||||
llama_pos * pos;
|
llama_pos * pos;
|
||||||
int32_t * n_seq_id;
|
int32_t * n_seq_id;
|
||||||
llama_seq_id ** seq_id;
|
llama_seq_id ** seq_id;
|
||||||
int8_t * logits;
|
int8_t * logits; // TODO: rename this to "output"
|
||||||
|
|
||||||
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
||||||
// for future-proof code, use the above fields instead and ignore everything below
|
// for future-proof code, use the above fields instead and ignore everything below
|
||||||
@ -260,7 +260,7 @@ extern "C" {
|
|||||||
|
|
||||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||||
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||||
bool embedding; // embedding mode only
|
bool embeddings; // if true, extract embeddings (together with logits)
|
||||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||||
|
|
||||||
// Abort callback
|
// Abort callback
|
||||||
@ -659,7 +659,7 @@ extern "C" {
|
|||||||
// shape: [n_embd] (1-dimensional)
|
// shape: [n_embd] (1-dimensional)
|
||||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||||
|
|
||||||
// Get the embeddings for the ith sequence
|
// Get the embeddings for the ith token
|
||||||
// llama_get_embeddings(ctx) + i*n_embd
|
// llama_get_embeddings(ctx) + i*n_embd
|
||||||
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user