mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 13:00:25 +01:00
llama : support batched embeddings (#5466)
* batched embedding: pool outputs by sequence id. updated embedding example * bring back non-causal attention * embd : minor improvements * llama : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
ad014bba97
commit
03bf161eb6
@ -1648,6 +1648,7 @@ class BertModel(Model):
|
||||
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self.gguf_writer.add_pooling_layer(True)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
def set_vocab(self):
|
||||
|
@ -7,6 +7,51 @@
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
static std::vector<std::string> split_lines(const std::string & s) {
|
||||
std::string line;
|
||||
std::vector<std::string> lines;
|
||||
std::stringstream ss(s);
|
||||
while (std::getline(ss, line)) {
|
||||
lines.push_back(line);
|
||||
}
|
||||
return lines;
|
||||
}
|
||||
|
||||
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
llama_batch_add(batch, tokens[i], i, { seq_id }, false);
|
||||
}
|
||||
}
|
||||
|
||||
static void normalize(float * vec, float * out, int n) {
|
||||
float norm = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
norm += vec[i] * vec[i];
|
||||
}
|
||||
norm = sqrt(norm);
|
||||
for (int i = 0; i < n; i++) {
|
||||
out[i] = vec[i] / norm;
|
||||
}
|
||||
}
|
||||
|
||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// run model
|
||||
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
if (llama_decode(ctx, batch) < 0) {
|
||||
fprintf(stderr, "%s : failed to decode\n", __func__);
|
||||
}
|
||||
|
||||
// normalize on copy
|
||||
for (int k = 0; k < n_seq; k++) {
|
||||
float * emb = llama_get_embeddings_ith(ctx, k);
|
||||
float * out = output + k * n_embd;
|
||||
normalize(emb, out, n_embd);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
@ -55,59 +100,84 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "%s\n", get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
int n_past = 0;
|
||||
// split the prompt into lines
|
||||
std::vector<std::string> prompts = split_lines(params.prompt);
|
||||
|
||||
// tokenize the prompt
|
||||
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
|
||||
// max batch size
|
||||
const uint64_t n_batch = params.n_batch;
|
||||
GGML_ASSERT(params.n_batch == params.n_ctx);
|
||||
|
||||
// tokenize the prompts and trim
|
||||
std::vector<std::vector<int32_t>> inputs;
|
||||
for (const auto & prompt : prompts) {
|
||||
auto inp = ::llama_tokenize(ctx, prompt, true);
|
||||
if (inp.size() > n_batch) {
|
||||
inp.resize(n_batch);
|
||||
}
|
||||
inputs.push_back(inp);
|
||||
}
|
||||
|
||||
// tokenization stats
|
||||
if (params.verbose_prompt) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
||||
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
|
||||
for (int i = 0; i < (int) embd_inp.size(); i++) {
|
||||
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
|
||||
for (int i = 0; i < (int) inputs.size(); i++) {
|
||||
fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
|
||||
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
|
||||
for (int j = 0; j < (int) inputs[i].size(); j++) {
|
||||
fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str());
|
||||
}
|
||||
fprintf(stderr, "\n\n");
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
if (embd_inp.size() > (size_t)n_ctx) {
|
||||
fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
|
||||
__func__, embd_inp.size(), n_ctx);
|
||||
return 1;
|
||||
}
|
||||
|
||||
while (!embd_inp.empty()) {
|
||||
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
|
||||
if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
n_past += n_tokens;
|
||||
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
|
||||
}
|
||||
// initialize batch
|
||||
const int n_prompts = prompts.size();
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
|
||||
|
||||
// allocate output
|
||||
const int n_embd = llama_n_embd(model);
|
||||
auto * embeddings = llama_get_embeddings(ctx);
|
||||
std::vector<float> embeddings(n_prompts * n_embd, 0);
|
||||
float * emb = embeddings.data();
|
||||
|
||||
// l2-normalize embeddings
|
||||
float norm = 0;
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
norm += embeddings[i] * embeddings[i];
|
||||
}
|
||||
norm = sqrt(norm);
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
embeddings[i] /= norm;
|
||||
// break into batches
|
||||
int p = 0; // number of prompts processed already
|
||||
int s = 0; // number of prompts in current batch
|
||||
for (int k = 0; k < n_prompts; k++) {
|
||||
// clamp to n_batch tokens
|
||||
auto & inp = inputs[k];
|
||||
const uint64_t n_toks = inp.size();
|
||||
|
||||
// encode if at capacity
|
||||
if (batch.n_tokens + n_toks > n_batch) {
|
||||
float * out = emb + p * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd);
|
||||
llama_batch_clear(batch);
|
||||
p += s;
|
||||
s = 0;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
printf("%f ", embeddings[i]);
|
||||
// add to batch
|
||||
batch_add_seq(batch, inp, s);
|
||||
s += 1;
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
// final batch
|
||||
float * out = emb + p * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd);
|
||||
|
||||
// print first 3 embeddings
|
||||
for (int j = 0; j < std::min(3, n_prompts); j++) {
|
||||
fprintf(stderr, "embedding %d: ", j);
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
fprintf(stderr, "%f ", emb[j * n_embd + i]);
|
||||
}
|
||||
fprintf(stderr, "\n\n");
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
// clean up
|
||||
llama_print_timings(ctx);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
|
@ -40,6 +40,7 @@ class Keys:
|
||||
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
|
||||
EXPERT_COUNT = "{arch}.expert_count"
|
||||
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
||||
POOLING_LAYER = "{arch}.pooling_layer"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
|
@ -360,6 +360,9 @@ class GGUFWriter:
|
||||
def add_causal_attention(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
|
||||
|
||||
def add_pooling_layer(self, value: bool) -> None:
|
||||
self.add_bool(Keys.LLM.POOLING_LAYER.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_dimension_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
|
||||
|
||||
|
59
llama.cpp
59
llama.cpp
@ -254,6 +254,7 @@ enum llm_kv {
|
||||
LLM_KV_TENSOR_DATA_LAYOUT,
|
||||
LLM_KV_EXPERT_COUNT,
|
||||
LLM_KV_EXPERT_USED_COUNT,
|
||||
LLM_KV_POOLING_LAYER,
|
||||
|
||||
LLM_KV_ATTENTION_HEAD_COUNT,
|
||||
LLM_KV_ATTENTION_HEAD_COUNT_KV,
|
||||
@ -311,6 +312,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
|
||||
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
|
||||
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
|
||||
{ LLM_KV_POOLING_LAYER, "%s.pooling_layer" },
|
||||
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
@ -1539,6 +1541,7 @@ struct llama_hparams {
|
||||
float f_max_alibi_bias;
|
||||
|
||||
bool causal_attn = true;
|
||||
bool pooling_layer = false;
|
||||
|
||||
|
||||
bool operator!=(const llama_hparams & other) const {
|
||||
@ -1601,6 +1604,7 @@ struct llama_cparams {
|
||||
|
||||
bool mul_mat_q;
|
||||
bool offload_kqv;
|
||||
bool do_pooling;
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
@ -1896,7 +1900,7 @@ struct llama_context {
|
||||
struct ggml_tensor * inp_pos; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
|
||||
struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
|
||||
struct ggml_tensor * inp_sum; // F32 [1, n_batch]
|
||||
struct ggml_tensor * inp_sum; // F32 [n_batch, n_batch]
|
||||
|
||||
#ifdef GGML_USE_MPI
|
||||
ggml_mpi_context * ctx_mpi = NULL;
|
||||
@ -3053,6 +3057,7 @@ static void llm_load_hparams(
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
|
||||
ml.get_key(LLM_KV_POOLING_LAYER, hparams.pooling_layer);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 3:
|
||||
@ -4859,7 +4864,7 @@ struct llm_build_context {
|
||||
const int32_t n_orig_ctx;
|
||||
|
||||
const bool do_rope_shift;
|
||||
const bool causal_attn;
|
||||
const bool do_pooling;
|
||||
|
||||
const llm_build_cb & cb;
|
||||
|
||||
@ -4903,7 +4908,7 @@ struct llm_build_context {
|
||||
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
|
||||
n_orig_ctx (cparams.n_yarn_orig_ctx),
|
||||
do_rope_shift (worst_case || kv_self.has_shift),
|
||||
causal_attn (hparams.causal_attn),
|
||||
do_pooling (hparams.pooling_layer && cparams.do_pooling),
|
||||
cb (cb),
|
||||
buf_compute_meta (lctx.buf_compute_meta) {
|
||||
// all initializations should be done in init()
|
||||
@ -5752,17 +5757,18 @@ struct llm_build_context {
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
// get input vectors with right size
|
||||
const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
|
||||
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
|
||||
struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
|
||||
struct ggml_tensor * inp_sum = ggml_view_2d(ctx0, lctx.inp_sum, n_tokens, n_tokens, stride1, 0);
|
||||
|
||||
// construct input embeddings (token, type, position)
|
||||
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
|
||||
|
||||
// token types are hardcoded to zero ("Sentence A")
|
||||
struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
|
||||
inpL = ggml_add(ctx0, inpL, type_row0);
|
||||
@ -5832,9 +5838,11 @@ struct llm_build_context {
|
||||
// final output
|
||||
cur = inpL;
|
||||
|
||||
// pooling
|
||||
cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
|
||||
cb(cur, "result_embed", -1);
|
||||
// pooling layer
|
||||
if (do_pooling) {
|
||||
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum);
|
||||
}
|
||||
cb(cur, "result_embd", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
@ -7367,7 +7375,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
|
||||
(hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
f = 0;
|
||||
@ -7378,7 +7387,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
{
|
||||
assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
|
||||
float * data = (float *) lctx.inp_sum->data;
|
||||
@ -7399,6 +7407,20 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
data[i] = lctx.kv_self.cells[i].delta;
|
||||
}
|
||||
}
|
||||
|
||||
if (hparams.pooling_layer && cparams.do_pooling) {
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
|
||||
float * data = (float *) lctx.inp_sum->data;
|
||||
|
||||
memset(lctx.inp_sum->data, 0, batch.n_tokens * batch.n_tokens * ggml_element_size(lctx.inp_sum));
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||
data[seq_id*n_tokens + i] = 1.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// decode a batch of tokens by evaluating the transformer
|
||||
@ -7510,7 +7532,7 @@ static int llama_decode_internal(
|
||||
embeddings = gf->nodes[gf->n_nodes - 3];
|
||||
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
||||
}
|
||||
} else if (strcmp(res->name, "result_embed") == 0) {
|
||||
} else if (strcmp(res->name, "result_embd") == 0) {
|
||||
embeddings = res;
|
||||
res = nullptr;
|
||||
} else {
|
||||
@ -7630,11 +7652,12 @@ static int llama_decode_internal(
|
||||
if (!lctx.embedding.empty()) {
|
||||
auto & embedding_out = lctx.embedding;
|
||||
|
||||
const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0;
|
||||
const int64_t embd_pos = res ? n_embd * (n_tokens-1) : 0;
|
||||
const int64_t embd_size = res ? n_embd : n_embd * n_tokens;
|
||||
|
||||
embedding_out.resize(n_embd);
|
||||
embedding_out.resize(embd_size);
|
||||
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
|
||||
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float));
|
||||
ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float));
|
||||
ggml_backend_synchronize(embeddings_backend);
|
||||
}
|
||||
|
||||
@ -10950,6 +10973,7 @@ struct llama_context_params llama_context_default_params() {
|
||||
/*.logits_all =*/ false,
|
||||
/*.embedding =*/ false,
|
||||
/*.offload_kqv =*/ true,
|
||||
/*.do_pooling =*/ true,
|
||||
};
|
||||
|
||||
return result;
|
||||
@ -11105,6 +11129,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.mul_mat_q = params.mul_mat_q;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.do_pooling = params.do_pooling;
|
||||
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||
@ -11270,7 +11295,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
|
||||
ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
|
||||
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
|
||||
ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch);
|
||||
ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
|
||||
|
||||
ggml_set_name(ctx->inp_tokens, "inp_tokens");
|
||||
ggml_set_name(ctx->inp_embd, "inp_embd");
|
||||
@ -12128,6 +12153,10 @@ float * llama_get_embeddings(struct llama_context * ctx) {
|
||||
return ctx->embedding.data();
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
||||
return ctx->embedding.data() + i*ctx->model.hparams.n_embd;
|
||||
}
|
||||
|
||||
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
|
||||
return model->vocab.id_to_token[token].text.c_str();
|
||||
}
|
||||
|
5
llama.h
5
llama.h
@ -236,6 +236,7 @@ extern "C" {
|
||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||
bool embedding; // embedding mode only
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
@ -628,6 +629,10 @@ extern "C" {
|
||||
// shape: [n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
|
||||
// Get the embeddings for the ith sequence
|
||||
// llama_get_embeddings(ctx) + i*n_embd
|
||||
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
//
|
||||
// Vocab
|
||||
//
|
||||
|
Loading…
x
Reference in New Issue
Block a user