mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-28 12:57:03 +01:00
allow to toggle embedding mode
This commit is contained in:
parent
f618e5060a
commit
bd3d9fbfed
2
Makefile
2
Makefile
@ -2,7 +2,7 @@
|
|||||||
BUILD_TARGETS = \
|
BUILD_TARGETS = \
|
||||||
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
|
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
|
||||||
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
|
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
|
||||||
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
|
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
|
||||||
|
|
||||||
# Binaries only useful for tests
|
# Binaries only useful for tests
|
||||||
TEST_TARGETS = \
|
TEST_TARGETS = \
|
||||||
|
@ -1304,7 +1304,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
|||||||
cparams.pooling_type = params.pooling_type;
|
cparams.pooling_type = params.pooling_type;
|
||||||
cparams.defrag_thold = params.defrag_thold;
|
cparams.defrag_thold = params.defrag_thold;
|
||||||
cparams.offload_kqv = !params.no_kv_offload;
|
cparams.offload_kqv = !params.no_kv_offload;
|
||||||
cparams.causal_attn = !params.embedding;
|
|
||||||
|
|
||||||
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
|
||||||
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
|
||||||
|
@ -125,15 +125,14 @@ int main(int argc, char* argv[])
|
|||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
cparams.embeddings = true;
|
|
||||||
cparams.causal_attn = false;
|
|
||||||
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
|
||||||
|
|
||||||
llama_backend_init();
|
llama_backend_init();
|
||||||
|
|
||||||
auto mdl = llama_load_model_from_file(params.model.c_str(), mparams);
|
auto mdl = llama_load_model_from_file(params.model.c_str(), mparams);
|
||||||
auto ctx = llama_new_context_with_model(mdl, cparams);
|
auto ctx = llama_new_context_with_model(mdl, cparams);
|
||||||
|
|
||||||
|
// set to embedding mode
|
||||||
|
llama_set_embeddings(ctx, true);
|
||||||
|
|
||||||
// ### Embedding/Representation ### taken sample from here:
|
// ### Embedding/Representation ### taken sample from here:
|
||||||
// https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
|
// https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
|
||||||
{
|
{
|
||||||
|
18
llama.cpp
18
llama.cpp
@ -1684,7 +1684,6 @@ struct llama_cparams {
|
|||||||
|
|
||||||
bool embeddings;
|
bool embeddings;
|
||||||
bool offload_kqv;
|
bool offload_kqv;
|
||||||
bool causal_attn;
|
|
||||||
enum llama_pooling_type pooling_type;
|
enum llama_pooling_type pooling_type;
|
||||||
|
|
||||||
ggml_backend_sched_eval_callback cb_eval;
|
ggml_backend_sched_eval_callback cb_eval;
|
||||||
@ -8030,7 +8029,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||||||
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
|
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cparams.causal_attn) {
|
GGML_ASSERT(
|
||||||
|
(hparams.causal_attn || cparams.embeddings) &&
|
||||||
|
"non-causal attention with generative models is not supported"
|
||||||
|
);
|
||||||
|
|
||||||
|
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
||||||
|
// But if cparams.embeddings is set, the attention will be non-causal nonetheless.
|
||||||
|
if (!cparams.embeddings) {
|
||||||
const int64_t n_kv = kv_self.n;
|
const int64_t n_kv = kv_self.n;
|
||||||
const int64_t n_tokens = batch.n_tokens;
|
const int64_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
@ -8055,7 +8061,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
|
// with causal attention, the mask needs to match the kv cache size
|
||||||
const int64_t n_tokens = batch.n_tokens;
|
const int64_t n_tokens = batch.n_tokens;
|
||||||
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
|
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
|
||||||
|
|
||||||
@ -11998,7 +12004,6 @@ struct llama_context_params llama_context_default_params() {
|
|||||||
/*.logits_all =*/ false,
|
/*.logits_all =*/ false,
|
||||||
/*.embeddings =*/ false,
|
/*.embeddings =*/ false,
|
||||||
/*.offload_kqv =*/ true,
|
/*.offload_kqv =*/ true,
|
||||||
/*.causal_attn =*/ true,
|
|
||||||
/*.abort_callback =*/ nullptr,
|
/*.abort_callback =*/ nullptr,
|
||||||
/*.abort_callback_data =*/ nullptr,
|
/*.abort_callback_data =*/ nullptr,
|
||||||
};
|
};
|
||||||
@ -12150,7 +12155,6 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
cparams.defrag_thold = params.defrag_thold;
|
cparams.defrag_thold = params.defrag_thold;
|
||||||
cparams.embeddings = params.embeddings;
|
cparams.embeddings = params.embeddings;
|
||||||
cparams.offload_kqv = params.offload_kqv;
|
cparams.offload_kqv = params.offload_kqv;
|
||||||
cparams.causal_attn = params.causal_attn;
|
|
||||||
cparams.pooling_type = params.pooling_type;
|
cparams.pooling_type = params.pooling_type;
|
||||||
|
|
||||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||||
@ -13165,6 +13169,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
|
|||||||
ctx->abort_callback_data = abort_callback_data;
|
ctx->abort_callback_data = abort_callback_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
|
||||||
|
ctx->cparams.embeddings = embeddings;
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_batch llama_batch_get_one(
|
struct llama_batch llama_batch_get_one(
|
||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
|
5
llama.h
5
llama.h
@ -262,7 +262,6 @@ extern "C" {
|
|||||||
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||||
bool embeddings; // if true, extract embeddings (together with logits)
|
bool embeddings; // if true, extract embeddings (together with logits)
|
||||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||||
bool causal_attn; // whether to use causal attention
|
|
||||||
|
|
||||||
// Abort callback
|
// Abort callback
|
||||||
// if it returns true, execution of llama_decode() will be aborted
|
// if it returns true, execution of llama_decode() will be aborted
|
||||||
@ -642,6 +641,10 @@ extern "C" {
|
|||||||
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
|
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
|
||||||
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
|
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
|
||||||
|
|
||||||
|
// Set whether to use causal attention or not
|
||||||
|
// If set to true, the model will only attend to the past tokens
|
||||||
|
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
|
||||||
|
|
||||||
// Set abort callback
|
// Set abort callback
|
||||||
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
|
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user