llama : add llama_model_is_recurrent to simplify figuring that out
This will make it easier to more cleanly support RWKV-v6 and Mamba-2.
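A minimal caller-side sketch of the new API, assuming a model already loaded from a GGUF file (the file name and the surrounding setup are illustrative, not part of this commit):

#include <stdio.h>
#include "llama.h"

int main(void) {
    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file("model.gguf", mparams); // path is a placeholder

    if (model == NULL) {
        return 1;
    }

    // replaces hard-coded architecture checks such as `arch == LLM_ARCH_MAMBA`
    if (llama_model_is_recurrent(model)) {
        // recurrent models keep a fixed-size state per sequence, so the KV cache
        // is sized from n_seq_max rather than n_ctx (see the hunks below)
        printf("model is recurrent\n");
    } else {
        printf("model uses a regular attention KV cache\n");
    }

    llama_free_model(model);
    llama_backend_free();
    return 0;
}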
commit 1be5ea7d97
parent b264eddbb2
@@ -508,6 +508,9 @@ extern "C" {
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
 
+    // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+    LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -3292,8 +3292,7 @@ static bool llama_kv_cache_init(
 
     cache.has_shift = false;
 
-    // TODO: find a nicer way to add other recurrent model architectures
-    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.recurrent = llama_model_is_recurrent(&model);
     cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
@@ -17235,7 +17234,7 @@ struct llama_context * llama_new_context_with_model(
     ggml_type type_v = params.type_v;
 
     // Mamba only needs a constant number of KV cache cells per sequence
-    if (model->arch == LLM_ARCH_MAMBA) {
+    if (llama_model_is_recurrent(model)) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
         kv_size = std::max((uint32_t) 1, params.n_seq_max);
         // it's probably best to keep as much precision as possible for the states
@@ -17709,6 +17708,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
     return model->hparams.dec_start_token_id;
 }
 
+bool llama_model_is_recurrent(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_MAMBA: return true;
+        default: return false;
+    }
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
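To illustrate the follow-up work mentioned in the commit message, supporting RWKV-v6 or Mamba-2 would then only need new cases in this one switch rather than edits at every call site; the LLM_ARCH_RWKV6 and LLM_ARCH_MAMBA2 enum values below are hypothetical and do not exist at this commit:

// hypothetical future extension, not part of this commit
bool llama_model_is_recurrent(const struct llama_model * model) {
    switch (model->arch) {
        case LLM_ARCH_MAMBA:  return true;
        case LLM_ARCH_MAMBA2: return true; // hypothetical enum value
        case LLM_ARCH_RWKV6:  return true; // hypothetical enum value
        default:              return false;
    }
}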