diff --git a/Makefile b/Makefile index 64a2d5bad..223d37eb4 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ BUILD_TARGETS = \ main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \ simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \ - speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o + speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o # Binaries only useful for tests TEST_TARGETS = \ diff --git a/common/common.cpp b/common/common.cpp index d8baf7782..c244db644 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1304,7 +1304,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.pooling_type = params.pooling_type; cparams.defrag_thold = params.defrag_thold; cparams.offload_kqv = !params.no_kv_offload; - cparams.causal_attn = !params.embedding; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index d06f64e24..7495e2894 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -125,15 +125,14 @@ int main(int argc, char* argv[]) return true; }; - cparams.embeddings = true; - cparams.causal_attn = false; - cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; - llama_backend_init(); auto mdl = llama_load_model_from_file(params.model.c_str(), mparams); auto ctx = llama_new_context_with_model(mdl, cparams); + // set to embedding mode + llama_set_embeddings(ctx, true); + // ### Embedding/Representation ### taken sample from here: // https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic { diff --git a/llama.cpp b/llama.cpp index 04816ea9e..79171c749 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1684,7 
+1684,6 @@ struct llama_cparams { bool embeddings; bool offload_kqv; - bool causal_attn; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; @@ -8030,7 +8029,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } - if (cparams.causal_attn) { + GGML_ASSERT( + (hparams.causal_attn || cparams.embeddings) && + "non-causal attention with generative models is not supported" + ); + + // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. + // But if cparams.embeddings is set, the attention will be non-causal nonetheless. + if (!cparams.embeddings) { const int64_t n_kv = kv_self.n; const int64_t n_tokens = batch.n_tokens; @@ -8055,7 +8061,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } else { - // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used) + // with causal attention, the mask needs to match the kv cache size const int64_t n_tokens = batch.n_tokens; const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; @@ -11998,7 +12004,6 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embeddings =*/ false, /*.offload_kqv =*/ true, - /*.causal_attn =*/ true, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -12150,7 +12155,6 @@ struct llama_context * llama_new_context_with_model( cparams.defrag_thold = params.defrag_thold; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; - cparams.causal_attn = params.causal_attn; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? 
hparams.n_ctx_train : params.n_ctx; @@ -13165,6 +13169,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback) ctx->abort_callback_data = abort_callback_data; } +void llama_set_embeddings(struct llama_context * ctx, bool embeddings) { + ctx->cparams.embeddings = embeddings; +} + struct llama_batch llama_batch_get_one( llama_token * tokens, int32_t n_tokens, diff --git a/llama.h b/llama.h index 6265d6901..0fe7b0105 100644 --- a/llama.h +++ b/llama.h @@ -262,7 +262,6 @@ extern "C" { bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU - bool causal_attn; // whether to use causal attention // Abort callback // if it returns true, execution of llama_decode() will be aborted @@ -642,6 +641,10 @@ extern "C" { // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); + // Set whether the context is in embeddings mode (updates the context's embeddings flag) + // If set to true, embeddings are extracted and the attention is non-causal + LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); + // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);