From 1be5ea7d97a2b470a2bd1ce2e764a02bdec72a83 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Tue, 20 Aug 2024 23:55:14 -0400
Subject: [PATCH] llama : add llama_model_is_recurrent to simplify figuring
 that out

This will make it easier to more cleanly support RWKV-v6 and Mamba-2.
---
 include/llama.h |  3 +++
 src/llama.cpp   | 12 +++++++++---
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 3c28cf0b5..64fea5583 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -508,6 +508,9 @@ extern "C" {
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
 
+    // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+    LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
diff --git a/src/llama.cpp b/src/llama.cpp
index bd319e62c..8b2108fd0 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3292,8 +3292,7 @@ static bool llama_kv_cache_init(
 
     cache.has_shift = false;
 
-    // TODO: find a nicer way to add other recurrent model architectures
-    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.recurrent = llama_model_is_recurrent(&model);
     cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
@@ -17235,7 +17234,7 @@ struct llama_context * llama_new_context_with_model(
     ggml_type type_v = params.type_v;
 
     // Mamba only needs a constant number of KV cache cells per sequence
-    if (model->arch == LLM_ARCH_MAMBA) {
+    if (llama_model_is_recurrent(model)) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
         kv_size = std::max((uint32_t) 1, params.n_seq_max);
         // it's probably best to keep as much precision as possible for the states
@@ -17709,6 +17708,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
     return model->hparams.dec_start_token_id;
 }
 
+bool llama_model_is_recurrent(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_MAMBA: return true;
+        default:             return false;
+    }
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
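
Note (not part of the patch): a minimal caller-side sketch of how the new llama_model_is_recurrent API could be queried after loading a model, assuming the existing public loading functions from llama.h (llama_backend_init, llama_model_default_params, llama_load_model_from_file, llama_free_model); the model path is a placeholder supplied by the user.

    // sketch: check whether a loaded model is recurrent (Mamba-style) before
    // deciding how to size its cache, mirroring llama_new_context_with_model
    #include "llama.h"
    #include <cstdio>

    int main(int argc, char ** argv) {
        if (argc < 2) {
            fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
            return 1;
        }

        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file(argv[1], mparams);
        if (model == NULL) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        if (llama_model_is_recurrent(model)) {
            // recurrent models keep a fixed-size state per sequence
            printf("recurrent model: one cache cell per sequence is enough\n");
        } else {
            // attention-based models need KV cache that grows with context length
            printf("non-recurrent model: KV cache sized from n_ctx\n");
        }

        llama_free_model(model);
        llama_backend_free();
        return 0;
    }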