context : add get_ctx_padding()

ggml-ci
Georgi Gerganov 2025-01-17 11:51:35 +02:00
parent 12719550a6
commit 587384e8f2
3 changed files with 10 additions and 1 deletion

@@ -64,6 +64,10 @@ llama_pos llama_context::pos_max() const {
     return kv_self.pos_max();
 }
 
+uint32_t llama_context::get_ctx_padding(const llama_cparams & cparams) const {
+    return kv_self.get_padding(cparams);
+}
+
 // TODO: improve
 void llama_context::reset() {
     inp_tokens = nullptr;
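
The new accessor simply forwards to the KV cache, so callers no longer have to reach into kv_self to learn the padding requirement. Below is a minimal standalone sketch of that delegation pattern; the type names and the flash_attn-dependent padding values are assumptions for illustration, not llama.cpp's actual API.

// Standalone sketch of the delegation pattern (illustrative types and values).
#include <cstdint>
#include <cstdio>

struct cparams_sketch {
    bool flash_attn = false;
};

struct kv_cache_sketch {
    // the cache decides its own padding requirement (values are assumed here)
    uint32_t get_padding(const cparams_sketch & cparams) const {
        return cparams.flash_attn ? 256u : 32u;
    }
};

struct context_sketch {
    kv_cache_sketch kv_self;

    // forwarding accessor, mirroring llama_context::get_ctx_padding()
    uint32_t get_ctx_padding(const cparams_sketch & cparams) const {
        return kv_self.get_padding(cparams);
    }
};

int main() {
    context_sketch ctx;
    cparams_sketch cparams;
    printf("padding = %u\n", ctx.get_ctx_padding(cparams)); // 32 with the assumed values
    return 0;
}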

@@ -84,8 +84,11 @@ struct llama_context {
             ggml_cgraph * graph,
             bool batched);
 
+    // max token position across all sequences in the current context
     llama_pos pos_max() const;
 
+    uint32_t get_ctx_padding(const llama_cparams & cparams) const;
+
     void reset();
 
     void prepare_k_shift();

@@ -7820,6 +7820,7 @@ static int llama_decode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
@@ -8154,6 +8155,7 @@ static int llama_encode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
@@ -8619,7 +8621,7 @@ struct llama_context * llama_init_from_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->kv_self.get_padding(cparams));
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->get_ctx_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
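
At init time, GGML_PAD rounds n_ctx up to a multiple of the padding reported by the new accessor, so the value matches the padded kv_self.n used later during inference. A quick check of the arithmetic, using the usual round-up-to-a-multiple formula that GGML_PAD performs and an illustrative padding of 32:

// Sketch of the n_ctx padding step; the padding value 32 is illustrative.
#include <cstdint>
#include <cstdio>

// round x up to the next multiple of n (same effect as ggml's GGML_PAD)
static uint32_t pad_up(uint32_t x, uint32_t n) {
    return (x + n - 1) / n * n;
}

int main() {
    const uint32_t padding = 32; // e.g. what ctx->get_ctx_padding(cparams) might return
    printf("%u -> %u\n", 4096u, pad_up(4096u, padding)); // 4096 (already a multiple)
    printf("%u -> %u\n", 4097u, pad_up(4097u, padding)); // 4128 (rounded up)
    return 0;
}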