mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 04:23:06 +01:00
parent
a88ad007de
commit
8e752a777b
@ -875,6 +875,12 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
return iparams;
|
return iparams;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
|
||||||
|
LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__);
|
||||||
|
llama_free_model(model);
|
||||||
|
return iparams;
|
||||||
|
}
|
||||||
|
|
||||||
if (!params.control_vectors.empty()) {
|
if (!params.control_vectors.empty()) {
|
||||||
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
|
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
|
||||||
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
|
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
|
||||||
|
@ -667,6 +667,9 @@ extern "C" {
|
|||||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||||
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
|
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
|
||||||
|
|
||||||
|
// Check if the context supports KV cache shifting
|
||||||
|
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
|
||||||
|
|
||||||
//
|
//
|
||||||
// State / sessions
|
// State / sessions
|
||||||
//
|
//
|
||||||
|
@ -18213,7 +18213,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
|||||||
|
|
||||||
// apply K-shift if needed
|
// apply K-shift if needed
|
||||||
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
|
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
|
||||||
if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
|
if (!llama_kv_cache_can_shift(&lctx)) {
|
||||||
GGML_ABORT("Deepseek2 does not support K-shift");
|
GGML_ABORT("Deepseek2 does not support K-shift");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -20462,6 +20462,10 @@ void llama_kv_cache_update(struct llama_context * ctx) {
|
|||||||
llama_kv_cache_update_internal(*ctx);
|
llama_kv_cache_update_internal(*ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool llama_kv_cache_can_shift(struct llama_context * ctx) {
|
||||||
|
return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
|
||||||
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
size_t llama_get_state_size(struct llama_context * ctx) {
|
size_t llama_get_state_size(struct llama_context * ctx) {
|
||||||
return llama_state_get_size(ctx);
|
return llama_state_get_size(ctx);
|
||||||
|
Loading…
Reference in New Issue
Block a user