context : add get_ctx_padding()

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-17 11:51:35 +02:00
parent 12719550a6
commit 587384e8f2
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 10 additions and 1 deletions

View File

@ -64,6 +64,10 @@ llama_pos llama_context::pos_max() const {
return kv_self.pos_max(); return kv_self.pos_max();
} }
uint32_t llama_context::get_ctx_padding(const llama_cparams & cparams) const {
return kv_self.get_padding(cparams);
}
// TODO: improve // TODO: improve
void llama_context::reset() { void llama_context::reset() {
inp_tokens = nullptr; inp_tokens = nullptr;

View File

@ -84,8 +84,11 @@ struct llama_context {
ggml_cgraph * graph, ggml_cgraph * graph,
bool batched); bool batched);
// max token position across all sequences in the current context
llama_pos pos_max() const; llama_pos pos_max() const;
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
void reset(); void reset();
void prepare_k_shift(); void prepare_k_shift();

View File

@ -7820,6 +7820,7 @@ static int llama_decode_impl(
} }
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1); llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
const llama_batch & batch = batch_allocr.batch; const llama_batch & batch = batch_allocr.batch;
@ -8154,6 +8155,7 @@ static int llama_encode_impl(
} }
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1); llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
const llama_batch & batch = batch_allocr.batch; const llama_batch & batch = batch_allocr.batch;
@ -8619,7 +8621,7 @@ struct llama_context * llama_init_from_model(
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
// this is necessary due to kv_self.n being padded later during inference // this is necessary due to kv_self.n being padded later during inference
cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->kv_self.get_padding(cparams)); cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->get_ctx_padding(cparams));
// with causal attention, the batch size is limited by the context size // with causal attention, the batch size is limited by the context size
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;