llama : fix embeddings (#5796)

* llama : fix embeddings

ggml-ci

* llama : do not use KV cache for non-causal models

ggml-ci

* embeddings : fix llama_batch_init arg

* llama : add pooling switch

* llama : distinguish token vs sequence embeddings

ggml-ci

* llama : assert pooling tensor

* llama : simplify causal mask condition

ggml-ci

* llama : assert input batch with pooling enabled

* readme : update API changes list
This commit is contained in:
Georgi Gerganov 2024-03-04 22:31:20 +02:00 committed by GitHub
parent e0843afe1b
commit 29ae62d2ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 359 additions and 134 deletions

View File

@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
### Recent API changes ### Recent API changes
- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
- [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849 - [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849
### Hot topics ### Hot topics

View File

@ -1292,7 +1292,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed; cparams.seed = params.seed;
cparams.logits_all = params.logits_all; cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding; cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type; cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base; cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale; cparams.rope_freq_scale = params.rope_freq_scale;

View File

@ -19,11 +19,11 @@ static std::vector<std::string> split_lines(const std::string & s) {
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) { static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
for (size_t i = 0; i < tokens.size(); i++) { for (size_t i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, false); llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
} }
} }
static void normalize(float * vec, float * out, int n) { static void normalize(const float * vec, float * out, int n) {
float norm = 0; float norm = 0;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
norm += vec[i] * vec[i]; norm += vec[i] * vec[i];
@ -45,10 +45,23 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
} }
// normalize on copy // normalize on copy
for (int k = 0; k < n_seq; k++) { for (int i = 0; i < batch.n_tokens; i++) {
float * emb = llama_get_embeddings_ith(ctx, k); if (!batch.logits[i]) {
float * out = output + k * n_embd; continue;
normalize(emb, out, n_embd); }
// try to get sequence embeddings - supported only when pooling_type is not NONE
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
continue;
}
}
float * out = output + batch.seq_id[i][0] * n_embd;
normalize(embd, out, n_embd);
} }
} }
@ -132,7 +145,7 @@ int main(int argc, char ** argv) {
// initialize batch // initialize batch
const int n_prompts = prompts.size(); const int n_prompts = prompts.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts); struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// allocate output // allocate output
const int n_embd = llama_n_embd(model); const int n_embd = llama_n_embd(model);
@ -145,6 +158,7 @@ int main(int argc, char ** argv) {
for (int k = 0; k < n_prompts; k++) { for (int k = 0; k < n_prompts; k++) {
// clamp to n_batch tokens // clamp to n_batch tokens
auto & inp = inputs[k]; auto & inp = inputs[k];
const uint64_t n_toks = inp.size(); const uint64_t n_toks = inp.size();
// encode if at capacity // encode if at capacity

34
examples/server-embd.py Normal file
View File

@ -0,0 +1,34 @@
import asyncio
import requests
import numpy as np
n = 8
result = []
async def requests_post_async(*args, **kwargs):
return await asyncio.to_thread(requests.post, *args, **kwargs)
async def main():
model_url = "http://127.0.0.1:6900"
responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
url= f"{model_url}/embedding",
json= {"content": str(i)*1024}
) for i in range(n)])
for response in responses:
embedding = response.json()["embedding"]
print(embedding[-8:])
result.append(embedding)
asyncio.run(main())
# compute cosine similarity
for i in range(n-1):
for j in range(i+1, n):
embedding1 = np.array(result[i])
embedding2 = np.array(result[j])
similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
print(f"Similarity between {i} and {j}: {similarity:.2f}")

View File

@ -1210,7 +1210,7 @@ struct llama_server_context
queue_results.send(res); queue_results.send(res);
} }
void send_embedding(server_slot &slot) void send_embedding(server_slot & slot, const llama_batch & batch)
{ {
task_result res; task_result res;
res.id = slot.task_id; res.id = slot.task_id;
@ -1219,6 +1219,7 @@ struct llama_server_context
res.stop = true; res.stop = true;
const int n_embd = llama_n_embd(model); const int n_embd = llama_n_embd(model);
if (!params.embedding) if (!params.embedding)
{ {
LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}}); LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
@ -1229,12 +1230,29 @@ struct llama_server_context
} }
else else
{ {
const float *data = llama_get_embeddings(ctx); for (int i = 0; i < batch.n_tokens; ++i) {
std::vector<float> embedding(data, data + n_embd); if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
}
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
res.result_json = json res.result_json = json
{ {
{"embedding", embedding}, {"embedding", std::vector<float>(n_embd, 0.0f)},
}; };
continue;
}
}
res.result_json = json
{
{"embedding", std::vector<float>(embd, embd + n_embd)},
};
}
} }
queue_results.send(res); queue_results.send(res);
} }
@ -1845,7 +1863,7 @@ struct llama_server_context
ga_i += ga_w/ga_n; ga_i += ga_w/ga_n;
} }
} }
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false); llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
slot_npast++; slot_npast++;
} }
@ -1881,7 +1899,7 @@ struct llama_server_context
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
{ {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
for (auto & slot : slots) for (auto & slot : slots)
{ {
@ -1954,7 +1972,7 @@ struct llama_server_context
// prompt evaluated for embedding // prompt evaluated for embedding
if (slot.embedding) if (slot.embedding)
{ {
send_embedding(slot); send_embedding(slot, batch_view);
slot.release(); slot.release();
slot.i_batch = -1; slot.i_batch = -1;
continue; continue;
@ -2036,6 +2054,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --pooling {none,mean,cls}\n");
printf(" pooling type for embeddings, use model default if unspecified\n");
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
@ -2276,6 +2296,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
params.yarn_beta_slow = std::stof(argv[i]); params.yarn_beta_slow = std::stof(argv[i]);
} }
else if (arg == "--pooling")
{
if (++i >= argc) {
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else { invalid_param = true; break; }
}
else if (arg == "--threads" || arg == "-t") else if (arg == "--threads" || arg == "-t")
{ {
if (++i >= argc) if (++i >= argc)
@ -2330,7 +2362,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break; break;
} }
params.n_batch = std::stoi(argv[i]); params.n_batch = std::stoi(argv[i]);
params.n_batch = std::min(512, params.n_batch);
} }
else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
{ {

285
llama.cpp
View File

@ -1682,7 +1682,9 @@ struct llama_cparams {
float yarn_beta_slow; float yarn_beta_slow;
float defrag_thold; float defrag_thold;
bool embeddings;
bool offload_kqv; bool offload_kqv;
enum llama_pooling_type pooling_type; enum llama_pooling_type pooling_type;
ggml_backend_sched_eval_callback cb_eval; ggml_backend_sched_eval_callback cb_eval;
@ -1972,7 +1974,7 @@ struct llama_context {
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
int32_t n_eval = 0; // number of eval calls int32_t n_eval = 0; // number of eval calls
// decode output (2-dimensional array: [n_tokens][n_vocab]) // logits output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits; std::vector<float> logits;
#ifndef NDEBUG #ifndef NDEBUG
// guard against access to unset logits // guard against access to unset logits
@ -1980,8 +1982,13 @@ struct llama_context {
#endif #endif
bool logits_all = false; bool logits_all = false;
// input embedding (1-dimensional array: [n_embd]) // embeddings output (2-dimensional array: [n_tokens][n_embd])
std::vector<float> embedding; // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
std::vector<float> embd;
// sequence embeddings output (map of [n_embd] vectors)
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
// memory buffers used to evaluate the model // memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta; std::vector<uint8_t> buf_compute_meta;
@ -5092,6 +5099,7 @@ static struct ggml_tensor * llm_build_kv(
llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
struct ggml_tensor * cur; struct ggml_tensor * cur;
cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il); q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
@ -6085,6 +6093,7 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
struct ggml_tensor * cur; struct ggml_tensor * cur;
@ -6092,6 +6101,7 @@ struct llm_build_context {
// get input vectors with right size // get input vectors with right size
const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type); const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0); struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0);
struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0); struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0);
@ -6112,39 +6122,38 @@ struct llm_build_context {
cb(inpL, "inp_norm", -1); cb(inpL, "inp_norm", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads) // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); struct ggml_tensor * KQ_mask = ggml_cont(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_tokens, n_tokens, n_tokens*ggml_type_size(lctx.inp_KQ_mask->type), 0));
cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens] cb(KQ_mask, "KQ_mask", -1); // [n_tokens, n_tokens]
// iterate layers // iterate layers
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * cur = inpL; struct ggml_tensor * cur = inpL;
struct ggml_tensor * Qcur;
struct ggml_tensor * Kcur;
struct ggml_tensor * Vcur;
// self-attention // self-attention
if (model.arch == LLM_ARCH_BERT) { if (model.arch == LLM_ARCH_BERT) {
struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
// seems like we just need to do this for Q?
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
} else { } else {
// compute Q and K and RoPE them // compute Q and K and RoPE them
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il); cb(cur, "wqkv", il);
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
@ -6163,13 +6172,41 @@ struct llm_build_context {
ext_factor, attn_factor, beta_fast, beta_slow ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
} }
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
cb(kq, "kq", il);
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
cb(v, "v", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
cb(kqv, "kqv", il);
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
cb(cur, "kqv_merged_cont", il);
ggml_build_forward_expand(gf, cur);
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
if (model.layers[il].bo) {
cb(cur, "kqv_wo", il);
}
if (model.layers[il].bo) {
cur = ggml_add(ctx0, cur, model.layers[il].bo);
}
cb(cur, "kqv_out", il);
// re-add the layer input // re-add the layer input
cur = ggml_add(ctx0, cur, inpL); cur = ggml_add(ctx0, cur, inpL);
@ -6209,16 +6246,29 @@ struct llm_build_context {
// final output // final output
cur = inpL; cur = inpL;
cb(cur, "result_embd", -1);
// pooling layer // pooling layer
if (pooling_type == LLAMA_POOLING_TYPE_MEAN) { switch (pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// nop
} break;
case LLAMA_POOLING_TYPE_MEAN:
{
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
} else if (pooling_type == LLAMA_POOLING_TYPE_CLS) { cb(cur, "result_embd_pooled", -1);
} break;
case LLAMA_POOLING_TYPE_CLS:
{
cur = ggml_get_rows(ctx0, cur, inp_cls); cur = ggml_get_rows(ctx0, cur, inp_cls);
} else { cb(cur, "result_embd_pooled", -1);
GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type"); } break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ASSERT(false && "Invalid pooling type");
} break;
} }
cb(cur, "result_embd", -1);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -7980,7 +8030,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
} }
{ if (hparams.causal_attn) {
const int64_t n_kv = kv_self.n; const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
@ -7995,16 +8045,40 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
for (int i = 0; i < n_kv; ++i) { for (int i = 0; i < n_kv; ++i) {
float f; float f;
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
(hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
f = -INFINITY; f = -INFINITY;
} else { } else {
f = 0; f = 0.0f;
} }
data[h*(n_kv*n_tokens) + j*n_kv + i] = f; data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
} }
} }
} }
} else {
// non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
const int64_t n_tokens = batch.n_tokens;
assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
float * data = (float *) lctx.inp_KQ_mask->data;
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_tokens; ++i) {
float f = -INFINITY;
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
if (batch.seq_id[i][s] == seq_id) {
f = 0.0f;
break;
}
}
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
}
}
}
} }
if (hparams.need_kq_pos) { if (hparams.need_kq_pos) {
@ -8023,13 +8097,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
float * data = (float *) lctx.inp_mean->data;
float * data = (float *) lctx.inp_mean->data;
memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
std::vector<uint64_t> sum(n_tokens, 0); std::vector<uint64_t> sum(n_tokens, 0);
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0]; const llama_seq_id seq_id = batch.seq_id[i][0];
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
sum[seq_id] += 1; sum[seq_id] += 1;
} }
@ -8051,11 +8128,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
uint32_t * data = (uint32_t *) lctx.inp_cls->data; uint32_t * data = (uint32_t *) lctx.inp_cls->data;
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0]; const llama_seq_id seq_id = batch.seq_id[i][0];
const llama_pos pos = batch.pos[i]; const llama_pos pos = batch.pos[i];
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
if (pos == 0) { if (pos == 0) {
data[seq_id] = i; data[seq_id] = i;
} }
@ -8169,6 +8251,8 @@ static int llama_decode_internal(
batch.seq_id = seq_id_arr.data(); batch.seq_id = seq_id_arr.data();
} }
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_cache_update(&lctx); llama_kv_cache_update(&lctx);
// if we have enough unused cells before the current head -> // if we have enough unused cells before the current head ->
@ -8186,6 +8270,7 @@ static int llama_decode_internal(
// if we start defragmenting the cache, the benefit from this will be more important // if we start defragmenting the cache, the benefit from this will be more important
kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
//kv_self.n = llama_kv_cache_cell_max(kv_self); //kv_self.n = llama_kv_cache_cell_max(kv_self);
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@ -8196,19 +8281,25 @@ static int llama_decode_internal(
// the output is always the last tensor in the graph // the output is always the last tensor in the graph
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
if (strcmp(res->name, "result_output") == 0) { if (!hparams.causal_attn) {
// the embeddings could be the second to last tensor, or the third to last tensor res = nullptr; // do not extract logits for embedding models such as BERT
if (strcmp(embeddings->name, "result_norm") != 0) {
embeddings = gf->nodes[gf->n_nodes - 3]; // token or sequence embeddings
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); embd = gf->nodes[gf->n_nodes - 1];
}
} else if (strcmp(res->name, "result_embd") == 0) { GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
embeddings = res;
res = nullptr;
} else { } else {
GGML_ASSERT(false); if (strcmp(res->name, "result_output") == 0) {
// the token embeddings could be the second to last tensor, or the third to last tensor
if (strcmp(embd->name, "result_norm") != 0) {
embd = gf->nodes[gf->n_nodes - 3];
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
}
} else {
GGML_ASSERT(false && "missing result_output tensor");
}
} }
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@ -8275,46 +8366,82 @@ static int llama_decode_internal(
logits_out.clear(); logits_out.clear();
#endif #endif
ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res); ggml_backend_t backend_res = ggml_backend_sched_get_node_backend(lctx.sched, res);
GGML_ASSERT(res_backend != nullptr); GGML_ASSERT(backend_res != nullptr);
if (batch.logits) { if (batch.logits) {
logits_out.resize(n_vocab * n_tokens); logits_out.resize(n_vocab * n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) { for (uint32_t i = 0; i < n_tokens; i++) {
if (batch.logits[i] == 0) { if (batch.logits[i] == 0) {
continue; continue;
} }
ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float)); ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
#ifndef NDEBUG #ifndef NDEBUG
logits_valid[i] = true; logits_valid[i] = true;
#endif #endif
} }
} else if (lctx.logits_all) { } else if (lctx.logits_all) {
logits_out.resize(n_vocab * n_tokens); logits_out.resize(n_vocab * n_tokens);
ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float)); ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
#ifndef NDEBUG #ifndef NDEBUG
std::fill(logits_valid.begin(), logits_valid.end(), true); std::fill(logits_valid.begin(), logits_valid.end(), true);
#endif #endif
} else { } else {
logits_out.resize(n_vocab); logits_out.resize(n_vocab);
ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float)); ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
#ifndef NDEBUG #ifndef NDEBUG
logits_valid[0] = true; logits_valid[0] = true;
#endif #endif
} }
ggml_backend_synchronize(res_backend); ggml_backend_synchronize(backend_res);
} }
// extract embeddings // extract embeddings
if (!lctx.embedding.empty()) { if (cparams.embeddings && embd) {
auto & embedding_out = lctx.embedding; ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
const int64_t embd_pos = res ? n_embd * (n_tokens-1) : 0; switch (cparams.pooling_type) {
const int64_t embd_size = res ? n_embd : n_embd * n_tokens; case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
auto & embd_out = lctx.embd;
embedding_out.resize(embd_size); if (batch.logits) {
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); embd_out.resize(n_embd * n_tokens);
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float)); for (uint32_t i = 0; i < n_tokens; i++) {
ggml_backend_synchronize(embeddings_backend); if (batch.logits[i] == 0) {
continue;
}
ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
}
}
} break;
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_MEAN:
{
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = batch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ASSERT(false && "unknown pooling type");
} break;
}
ggml_backend_synchronize(backend_embd);
} }
// measure the performance only for the single-token evals // measure the performance only for the single-token evals
@ -11864,7 +11991,7 @@ struct llama_context_params llama_context_default_params() {
/*.type_k =*/ GGML_TYPE_F16, /*.type_k =*/ GGML_TYPE_F16,
/*.type_v =*/ GGML_TYPE_F16, /*.type_v =*/ GGML_TYPE_F16,
/*.logits_all =*/ false, /*.logits_all =*/ false,
/*.embedding =*/ false, /*.embeddings =*/ false,
/*.offload_kqv =*/ true, /*.offload_kqv =*/ true,
/*.abort_callback =*/ nullptr, /*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr, /*.abort_callback_data =*/ nullptr,
@ -12015,6 +12142,7 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_beta_fast = params.yarn_beta_fast; cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.defrag_thold = params.defrag_thold; cparams.defrag_thold = params.defrag_thold;
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv; cparams.offload_kqv = params.offload_kqv;
cparams.pooling_type = params.pooling_type; cparams.pooling_type = params.pooling_type;
@ -12192,8 +12320,8 @@ struct llama_context * llama_new_context_with_model(
// resized during inference, reserve maximum // resized during inference, reserve maximum
ctx->logits.reserve(hparams.n_vocab*cparams.n_batch); ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
if (params.embedding) { if (params.embeddings) {
ctx->embedding.resize(hparams.n_embd); ctx->embd.reserve(hparams.n_embd*cparams.n_batch);
} }
// graph inputs // graph inputs
@ -12628,7 +12756,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
// assume worst case for logits although only currently set ones are serialized // assume worst case for logits although only currently set ones are serialized
const size_t s_logits = ctx->logits.capacity() * sizeof(float); const size_t s_logits = ctx->logits.capacity() * sizeof(float);
const size_t s_embedding_size = sizeof(size_t); const size_t s_embedding_size = sizeof(size_t);
const size_t s_embedding = ctx->embedding.size() * sizeof(float); const size_t s_embedding = ctx->embd.capacity() * sizeof(float);
const size_t s_kv_buf_size = sizeof(size_t); const size_t s_kv_buf_size = sizeof(size_t);
const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_head = sizeof(uint32_t);
const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t);
@ -12737,12 +12865,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
// copy embeddings // copy embeddings
{ {
const size_t embedding_size = ctx->embedding.size(); const size_t embeddings_size = ctx->embd.size();
data_ctx->write(&embedding_size, sizeof(embedding_size)); data_ctx->write(&embeddings_size, sizeof(embeddings_size));
if (embedding_size) { if (embeddings_size) {
data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float)); data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float));
} }
} }
@ -12846,15 +12974,17 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
// set embeddings // set embeddings
{ {
size_t embedding_size; size_t embeddings_size;
memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size); memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
GGML_ASSERT(ctx->embedding.capacity() == embedding_size); GGML_ASSERT(ctx->embd.capacity() == embeddings_size);
if (embedding_size) { if (embeddings_size) {
memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float)); ctx->embd.resize(embeddings_size);
inp += embedding_size * sizeof(float);
memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float));
inp += embeddings_size * sizeof(float);
} }
} }
@ -13104,11 +13234,20 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
} }
float * llama_get_embeddings(struct llama_context * ctx) { float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embedding.data(); return ctx->embd.data();
} }
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
return ctx->embedding.data() + i*ctx->model.hparams.n_embd; return ctx->embd.data() + i*ctx->model.hparams.n_embd;
}
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
auto it = ctx->embd_seq.find(seq_id);
if (it == ctx->embd_seq.end()) {
return nullptr;
}
return it->second.data();
} }
const char * llama_token_get_text(const struct llama_model * model, llama_token token) { const char * llama_token_get_text(const struct llama_model * model, llama_token token) {

18
llama.h
View File

@ -163,7 +163,7 @@ extern "C" {
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence // - pos : the positions of the respective token in the sequence
// - seq_id : the sequence to which the respective token belongs // - seq_id : the sequence to which the respective token belongs
// - logits : if zero, the logits for the respective token will not be output // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
// //
typedef struct llama_batch { typedef struct llama_batch {
int32_t n_tokens; int32_t n_tokens;
@ -173,7 +173,7 @@ extern "C" {
llama_pos * pos; llama_pos * pos;
int32_t * n_seq_id; int32_t * n_seq_id;
llama_seq_id ** seq_id; llama_seq_id ** seq_id;
int8_t * logits; int8_t * logits; // TODO: rename this to "output"
// NOTE: helpers for smooth API transition - can be deprecated in the future // NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below // for future-proof code, use the above fields instead and ignore everything below
@ -260,7 +260,7 @@ extern "C" {
// Keep the booleans together to avoid misalignment during copy-by-value. // Keep the booleans together to avoid misalignment during copy-by-value.
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embedding; // embedding mode only bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
// Abort callback // Abort callback
@ -655,14 +655,20 @@ extern "C" {
// llama_get_logits(ctx) + i*n_vocab // llama_get_logits(ctx) + i*n_vocab
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
// Get the embeddings for the input // Get all output token embeddings
// shape: [n_embd] (1-dimensional) // shape: [n_tokens*n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Get the embeddings for the ith sequence // Get the embeddings for the ith token
// llama_get_embeddings(ctx) + i*n_embd // llama_get_embeddings(ctx) + i*n_embd
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
// Get the embeddings for a sequence id
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
// //
// Vocab // Vocab
// //