From 587384e8f20c1d92a845586bc94de072e7e10528 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 17 Jan 2025 11:51:35 +0200
Subject: [PATCH] context : add get_ctx_padding()

ggml-ci
---
 src/llama-context.cpp | 4 ++++
 src/llama-context.h   | 3 +++
 src/llama.cpp         | 4 +++-
 3 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index daea125fe..6a73659d0 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -64,6 +64,10 @@ llama_pos llama_context::pos_max() const {
     return kv_self.pos_max();
 }
 
+uint32_t llama_context::get_ctx_padding(const llama_cparams & cparams) const {
+    return kv_self.get_padding(cparams);
+}
+
 // TODO: improve
 void llama_context::reset() {
     inp_tokens = nullptr;
diff --git a/src/llama-context.h b/src/llama-context.h
index bc33fc6ef..45eaafaad 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -84,8 +84,11 @@ struct llama_context {
             ggml_cgraph * graph,
             bool batched);
 
+    // max token position across all sequences in the current context
     llama_pos pos_max() const;
 
+    uint32_t get_ctx_padding(const llama_cparams & cparams) const;
+
     void reset();
 
     void prepare_k_shift();
diff --git a/src/llama.cpp b/src/llama.cpp
index 2b2d8f4b1..59cc292b5 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7820,6 +7820,7 @@ static int llama_decode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
@@ -8154,6 +8155,7 @@ static int llama_encode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
@@ -8619,7 +8621,7 @@ struct llama_context * llama_init_from_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->kv_self.get_padding(cparams));
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->get_ctx_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
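
Not part of the patch: a minimal standalone sketch of what the GGML_PAD rounding in llama_init_from_model does to cparams.n_ctx. The padding value (256) and the requested context size (4000) below are assumed purely for illustration; in the real code the padding comes from kv_self.get_padding(cparams), now reached through the new get_ctx_padding() wrapper.

    // illustrative sketch only; pad_up() mirrors the GGML_PAD(x, n) idea of
    // rounding x up to the next multiple of n (the real macro lives in ggml.h)
    #include <cstdint>
    #include <cstdio>

    static uint32_t pad_up(uint32_t x, uint32_t n) {
        return ((x + n - 1) / n) * n;
    }

    int main() {
        const uint32_t n_ctx_requested = 4000; // hypothetical user-requested context size
        const uint32_t padding         = 256;  // hypothetical value returned by get_ctx_padding()

        // mirrors: cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->get_ctx_padding(cparams));
        const uint32_t n_ctx_padded = pad_up(n_ctx_requested, padding);

        std::printf("n_ctx: %u -> %u\n", n_ctx_requested, n_ctx_padded); // prints 4000 -> 4096
        return 0;
    }

The wrapper keeps the padded n_ctx consistent with how kv_self.n is padded later during inference, without callers reaching into kv_self directly.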