mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-30 22:03:03 +01:00
context : add get_ctx_padding()
ggml-ci
This commit is contained in:
parent
12719550a6
commit
587384e8f2
@ -64,6 +64,10 @@ llama_pos llama_context::pos_max() const {
|
||||
return kv_self.pos_max();
|
||||
}
|
||||
|
||||
uint32_t llama_context::get_ctx_padding(const llama_cparams & cparams) const {
|
||||
return kv_self.get_padding(cparams);
|
||||
}
|
||||
|
||||
// TODO: improve
|
||||
void llama_context::reset() {
|
||||
inp_tokens = nullptr;
|
||||
|
@ -84,8 +84,11 @@ struct llama_context {
|
||||
ggml_cgraph * graph,
|
||||
bool batched);
|
||||
|
||||
// max token position across all sequences in the current context
|
||||
llama_pos pos_max() const;
|
||||
|
||||
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
|
||||
|
||||
void reset();
|
||||
|
||||
void prepare_k_shift();
|
||||
|
@ -7820,6 +7820,7 @@ static int llama_decode_impl(
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
|
||||
|
||||
const llama_batch & batch = batch_allocr.batch;
|
||||
@ -8154,6 +8155,7 @@ static int llama_encode_impl(
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
|
||||
|
||||
const llama_batch & batch = batch_allocr.batch;
|
||||
@ -8619,7 +8621,7 @@ struct llama_context * llama_init_from_model(
|
||||
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
|
||||
|
||||
// this is necessary due to kv_self.n being padded later during inference
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->kv_self.get_padding(cparams));
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->get_ctx_padding(cparams));
|
||||
|
||||
// with causal attention, the batch size is limited by the context size
|
||||
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
||||
|
Loading…
Reference in New Issue
Block a user