mirror of https://github.com/ggerganov/llama.cpp.git
allow to toggle embedding mode

commit bd3d9fbfed
parent f618e5060a
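In short, the commit drops the creation-time causal_attn context parameter and adds llama_set_embeddings(), so a single context can be switched between embedding extraction and text generation at runtime. A minimal usage sketch under that reading (the model path, the decode loops, and the cleanup order are placeholders for illustration, not part of the commit):

#include "llama.h"

// minimal sketch: one context, toggled between embedding and generation mode
int main() {
    llama_backend_init();

    llama_model_params   mparams = llama_model_default_params();
    llama_context_params cparams = llama_context_default_params();

    llama_model   * mdl = llama_load_model_from_file("model.gguf", mparams); // placeholder path
    llama_context * ctx = llama_new_context_with_model(mdl, cparams);

    llama_set_embeddings(ctx, true);   // non-causal attention, embedding outputs
    // ... llama_decode() calls that read embeddings ...

    llama_set_embeddings(ctx, false);  // back to regular generation
    // ... llama_decode() calls that sample tokens ...

    llama_free(ctx);
    llama_free_model(mdl);
    llama_backend_free();
    return 0;
}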
Makefile (2 lines changed)

@@ -2,7 +2,7 @@
 BUILD_TARGETS = \
 	main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
 	simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
-	speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
+	speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
 
 # Binaries only useful for tests
 TEST_TARGETS = \
common/common.cpp

@@ -1304,7 +1304,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params
     cparams.pooling_type = params.pooling_type;
     cparams.defrag_thold = params.defrag_thold;
     cparams.offload_kqv  = !params.no_kv_offload;
-    cparams.causal_attn  = !params.embedding;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
examples/gritlm/gritlm.cpp

@@ -125,15 +125,14 @@ int main(int argc, char* argv[])
         return true;
     };
 
-    cparams.embeddings = true;
-    cparams.causal_attn = false;
-    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
-
     llama_backend_init();
 
     auto mdl = llama_load_model_from_file(params.model.c_str(), mparams);
     auto ctx = llama_new_context_with_model(mdl, cparams);
 
+    // set to embedding mode
+    llama_set_embeddings(ctx, true);
+
     // ### Embedding/Representation ### taken sample from here:
     // https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
     {
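The hunk stops right where the embedding block opens. A hedged sketch of what such a block can do with the toggled context — the helper name, the flat layout, and the assumption that the batch was built with output flags enabled for the tokens of interest are illustrative, not taken from the diff:

#include <cstddef>
#include <vector>
#include "llama.h"

// decode one pre-built batch in embedding mode and copy out the per-token embeddings
static std::vector<float> embed_batch(llama_context * ctx, llama_batch batch, int n_embd) {
    llama_set_embeddings(ctx, true);                // make sure the context is in embedding mode
    if (llama_decode(ctx, batch) != 0) {
        return {};
    }
    const float * emb = llama_get_embeddings(ctx);  // assumed layout: n_tokens * n_embd values
    if (emb == nullptr) {
        return {};
    }
    return std::vector<float>(emb, emb + (size_t) batch.n_tokens * n_embd);
}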
llama.cpp (18 lines changed)

@@ -1684,7 +1684,6 @@ struct llama_cparams {
 
     bool embeddings;
     bool offload_kqv;
-    bool causal_attn;
     enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
@@ -8030,7 +8029,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (cparams.causal_attn) {
+    GGML_ASSERT(
+        (hparams.causal_attn || cparams.embeddings) &&
+        "non-causal attention with generative models is not supported"
+    );
+
+    // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+    // But if cparams.embeddings is set, the attention will be non-causal nonetheless.
+    if (!cparams.embeddings) {
         const int64_t n_kv = kv_self.n;
         const int64_t n_tokens = batch.n_tokens;
 
@@ -8055,7 +8061,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
     } else {
-        // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
+        // with causal attention, the mask needs to match the kv cache size
         const int64_t n_tokens = batch.n_tokens;
         const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
 
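For reference, the two hunks above reduce to a small rule: hparams.causal_attn says whether the model is generative and uses the KV cache, while cparams.embeddings now decides which mask is actually built. A standalone restatement of that rule (illustrative only, not code from the commit):

#include <cassert>

// hparams_causal_attn: the model is capable of generation and uses the KV cache
// cparams_embeddings : the context is currently in embedding mode
// returns true when the causal (KV-cache sized) mask should be built
static bool use_causal_mask(bool hparams_causal_attn, bool cparams_embeddings) {
    // a non-causal (embedding-only) model cannot be used generatively
    assert(hparams_causal_attn || cparams_embeddings);
    // embedding mode always gets the non-causal mask, even on a generative model
    return !cparams_embeddings;
}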
@@ -11998,7 +12004,6 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all          =*/ false,
         /*.embeddings          =*/ false,
         /*.offload_kqv         =*/ true,
-        /*.causal_attn         =*/ true,
         /*.abort_callback      =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,
     };
@@ -12150,7 +12155,6 @@ struct llama_context * llama_new_context_with_model(
     cparams.defrag_thold = params.defrag_thold;
     cparams.embeddings   = params.embeddings;
     cparams.offload_kqv  = params.offload_kqv;
-    cparams.causal_attn  = params.causal_attn;
     cparams.pooling_type = params.pooling_type;
 
     cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -13165,6 +13169,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }
 
+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
 struct llama_batch llama_batch_get_one(
         llama_token * tokens,
         int32_t n_tokens,
llama.h (5 lines changed)

@@ -262,7 +262,6 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool causal_attn; // whether to use causal attention
 
        // Abort callback
        // if it returns true, execution of llama_decode() will be aborted
@@ -642,6 +641,10 @@ extern "C" {
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
     LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
 
+    // Set whether to use causal attention or not
+    // If set to true, the model will only attend to the past tokens
+    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+
     // Set abort callback
     LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
 
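Taken together, the header hunks remove causal_attn from llama_context_params and expose the switch on the live context instead. A hedged migration sketch for callers that used to set the flag at creation time (the removed field appears only in a comment, since it no longer exists):

#include "llama.h"

// before this commit (no longer compiles):
//     cparams.embeddings  = true;
//     cparams.causal_attn = false;   // field removed from llama_context_params
//
// after this commit: toggle on the live context instead
void enable_embedding_mode(llama_context * ctx) {
    // causality now follows the model's hparams and the embeddings flag
    llama_set_embeddings(ctx, true);
}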