allow to toggle embedding mode

This commit is contained in:
Douglas Hanley 2024-03-07 11:55:27 -06:00
parent f618e5060a
commit bd3d9fbfed
5 changed files with 21 additions and 12 deletions

View File

@ -2,7 +2,7 @@
BUILD_TARGETS = \ BUILD_TARGETS = \
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \ main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \ simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
# Binaries only useful for tests # Binaries only useful for tests
TEST_TARGETS = \ TEST_TARGETS = \

View File

@ -1304,7 +1304,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.pooling_type = params.pooling_type; cparams.pooling_type = params.pooling_type;
cparams.defrag_thold = params.defrag_thold; cparams.defrag_thold = params.defrag_thold;
cparams.offload_kqv = !params.no_kv_offload; cparams.offload_kqv = !params.no_kv_offload;
cparams.causal_attn = !params.embedding;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v); cparams.type_v = kv_cache_type_from_str(params.cache_type_v);

View File

@ -125,15 +125,14 @@ int main(int argc, char* argv[])
return true; return true;
}; };
cparams.embeddings = true;
cparams.causal_attn = false;
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
llama_backend_init(); llama_backend_init();
auto mdl = llama_load_model_from_file(params.model.c_str(), mparams); auto mdl = llama_load_model_from_file(params.model.c_str(), mparams);
auto ctx = llama_new_context_with_model(mdl, cparams); auto ctx = llama_new_context_with_model(mdl, cparams);
// set to embedding mode
llama_set_embeddings(ctx, true);
// ### Embedding/Representation ### taken sample from here: // ### Embedding/Representation ### taken sample from here:
// https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic // https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
{ {

View File

@ -1684,7 +1684,6 @@ struct llama_cparams {
bool embeddings; bool embeddings;
bool offload_kqv; bool offload_kqv;
bool causal_attn;
enum llama_pooling_type pooling_type; enum llama_pooling_type pooling_type;
ggml_backend_sched_eval_callback cb_eval; ggml_backend_sched_eval_callback cb_eval;
@ -8030,7 +8029,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
} }
if (cparams.causal_attn) { GGML_ASSERT(
(hparams.causal_attn || cparams.embeddings) &&
"non-causal attention with generative models is not supported"
);
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
// But if cparams.embeddings is set, the attention will be non-causal nonetheless.
if (!cparams.embeddings) {
const int64_t n_kv = kv_self.n; const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
@ -8055,7 +8061,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
} }
} }
} else { } else {
// non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used) // with causal attention, the mask needs to match the kv cache size
const int64_t n_tokens = batch.n_tokens; const int64_t n_tokens = batch.n_tokens;
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
@ -11998,7 +12004,6 @@ struct llama_context_params llama_context_default_params() {
/*.logits_all =*/ false, /*.logits_all =*/ false,
/*.embeddings =*/ false, /*.embeddings =*/ false,
/*.offload_kqv =*/ true, /*.offload_kqv =*/ true,
/*.causal_attn =*/ true,
/*.abort_callback =*/ nullptr, /*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr, /*.abort_callback_data =*/ nullptr,
}; };
@ -12150,7 +12155,6 @@ struct llama_context * llama_new_context_with_model(
cparams.defrag_thold = params.defrag_thold; cparams.defrag_thold = params.defrag_thold;
cparams.embeddings = params.embeddings; cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv; cparams.offload_kqv = params.offload_kqv;
cparams.causal_attn = params.causal_attn;
cparams.pooling_type = params.pooling_type; cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@ -13165,6 +13169,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
ctx->abort_callback_data = abort_callback_data; ctx->abort_callback_data = abort_callback_data;
} }
// Switch an existing context into or out of embedding mode.
// The value is stored in ctx->cparams.embeddings; llama_set_inputs branches on
// cparams.embeddings when it prepares the attention inputs for a batch.
void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
    bool * mode = &ctx->cparams.embeddings;
    *mode = embeddings;
}
struct llama_batch llama_batch_get_one( struct llama_batch llama_batch_get_one(
llama_token * tokens, llama_token * tokens,
int32_t n_tokens, int32_t n_tokens,

View File

@ -262,7 +262,6 @@ extern "C" {
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits) bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool causal_attn; // whether to use causal attention
// Abort callback // Abort callback
// if it returns true, execution of llama_decode() will be aborted // if it returns true, execution of llama_decode() will be aborted
@ -642,6 +641,10 @@ extern "C" {
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
// Set whether the context runs in embedding mode
// If set to true, embeddings are extracted and attention is non-causal, even for models that support causal attention
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
// Set abort callback // Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);