context : initial need_reserve logic

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-17 14:42:09 +02:00
parent 2f6d767fc5
commit 1f7dc12222
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 269 additions and 245 deletions

View File

@ -576,9 +576,7 @@ ggml_tensor * llama_context::build_lora_mm_id(
return res;
}
bool llama_context::kv_self_update() {
bool need_reserve = false;
void llama_context::kv_self_update() {
auto & kv = kv_self;
if (kv.has_shift) {
@ -655,12 +653,14 @@ bool llama_context::kv_self_update() {
ggml_free(ctx0);
need_reserve = true;
kv.do_defrag = false;
}
return need_reserve;
need_reserve = true;
}
}
void llama_kv_self_update(llama_context * ctx) {
ctx->kv_self_update();
}
void llama_context::build_attn_inp(
@ -1824,6 +1824,165 @@ int32_t llama_apply_adapter_cvec(
return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
}
//
// kv cache view
//
struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
}
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
llama_kv_cache_view_update(view, ctx->kv_self);
}
//
// kv cache
//
// deprecated
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
return llama_kv_self_n_tokens(ctx);
}
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
return llama_kv_cache_n_tokens(&ctx->kv_self);
}
// deprecated
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
return llama_kv_self_used_cells(ctx);
}
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
return llama_kv_cache_used_cells(&ctx->kv_self);
}
// deprecated
void llama_kv_cache_clear(llama_context * ctx) {
llama_kv_self_clear(ctx);
}
void llama_kv_self_clear(llama_context * ctx) {
llama_kv_cache_clear(&ctx->kv_self);
}
// deprecated
bool llama_kv_cache_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
}
bool llama_kv_self_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1);
}
// deprecated
void llama_kv_cache_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_self_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
}
// deprecated
void llama_kv_cache_seq_keep(
llama_context * ctx,
llama_seq_id seq_id) {
return llama_kv_self_seq_keep(ctx, seq_id);
}
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id);
}
// deprecated
void llama_kv_cache_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
}
void llama_kv_self_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta);
}
// deprecated
void llama_kv_cache_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
}
void llama_kv_self_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d);
}
// deprecated
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_self_seq_pos_max(ctx, seq_id);
}
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id);
}
// deprecated
void llama_kv_cache_defrag(llama_context * ctx) {
return llama_kv_self_defrag(ctx);
}
void llama_kv_self_defrag(llama_context * ctx) {
return llama_kv_cache_defrag(&ctx->kv_self);
}
// deprecated
bool llama_kv_cache_can_shift(const llama_context * ctx) {
return llama_kv_self_can_shift(ctx);
}
bool llama_kv_self_can_shift(const llama_context * ctx) {
return llama_kv_cache_can_shift(&ctx->kv_self);
}
// deprecated
void llama_kv_cache_update(llama_context * ctx) {
llama_kv_self_update(ctx);
}
// llama state API

View File

@ -62,6 +62,7 @@ struct llama_context {
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
bool logits_all = false;
bool need_reserve = false;
// embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
@ -87,6 +88,7 @@ struct llama_context {
// max token position across all sequences in the current context
llama_pos pos_max() const;
// certain implementations could require a padding for the context size
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
void reset();
@ -140,7 +142,7 @@ struct llama_context {
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
// return true if need to reserve new worst-case graph
bool kv_self_update();
void kv_self_update();
void build_attn_inp(
ggml_context * ctx0,

View File

@ -28,57 +28,6 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
// loading time will be recalculated after the first eval, so
// we take page faults deferred by mmap() into consideration
model.t_load_us = 0;
time_meas tm(model.t_load_us);
model.t_start_us = tm.t_start_us;
try {
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
ml.print_info();
model.hparams.vocab_only = params.vocab_only;
try {
model.load_arch(ml);
} catch(const std::exception & e) {
throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
}
try {
model.load_hparams(ml);
} catch(const std::exception & e) {
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
}
try {
model.load_vocab(ml);
} catch(const std::exception & e) {
throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
}
model.load_stats(ml);
model.print_info();
if (params.vocab_only) {
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
return 0;
}
if (!model.load_tensors(ml)) {
return -2;
}
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
return -1;
}
return 0;
}
//
// llm_build
//
@ -7951,6 +7900,30 @@ static int llama_decode_impl(
}
}
// reserve a worst case graph if needed
// TODO: extract to a function
if (lctx.need_reserve) {
const auto & cparams = lctx.cparams;
const auto & model = lctx.model;
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched.get());
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
lctx.need_reserve = false;
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
ggml_backend_sched_reset(lctx.sched.get());
@ -8206,6 +8179,31 @@ static int llama_encode_impl(
lctx.prepare_decode(ubatch);
// reserve a worst case graph if needed
// TODO: extract to a function
if (lctx.need_reserve) {
// TODO: extract to a function
const auto & cparams = lctx.cparams;
const auto & model = lctx.model;
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched.get());
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
lctx.need_reserve = false;
}
ggml_backend_sched_reset(lctx.sched.get());
ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
@ -8419,6 +8417,57 @@ int64_t llama_time_us(void) {
return ggml_time_us();
}
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
// loading time will be recalculated after the first eval, so
// we take page faults deferred by mmap() into consideration
model.t_load_us = 0;
time_meas tm(model.t_load_us);
model.t_start_us = tm.t_start_us;
try {
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
ml.print_info();
model.hparams.vocab_only = params.vocab_only;
try {
model.load_arch(ml);
} catch(const std::exception & e) {
throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
}
try {
model.load_hparams(ml);
} catch(const std::exception & e) {
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
}
try {
model.load_vocab(ml);
} catch(const std::exception & e) {
throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
}
model.load_stats(ml);
model.print_info();
if (params.vocab_only) {
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
return 0;
}
if (!model.load_tensors(ml)) {
return -2;
}
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
return -1;
}
return 0;
}
static struct llama_model * llama_model_load_from_file_impl(
const std::string & path_model,
std::vector<std::string> & splits,
@ -8879,192 +8928,6 @@ struct llama_context * llama_new_context_with_model(
return llama_init_from_model(model, params);
}
//
// kv cache view
//
struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
}
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
llama_kv_cache_view_update(view, ctx->kv_self);
}
//
// kv cache
//
// deprecated
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
return llama_kv_self_n_tokens(ctx);
}
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
return llama_kv_cache_n_tokens(&ctx->kv_self);
}
// deprecated
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
return llama_kv_self_used_cells(ctx);
}
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
return llama_kv_cache_used_cells(&ctx->kv_self);
}
// deprecated
void llama_kv_cache_clear(llama_context * ctx) {
llama_kv_self_clear(ctx);
}
void llama_kv_self_clear(llama_context * ctx) {
llama_kv_cache_clear(&ctx->kv_self);
}
// deprecated
bool llama_kv_cache_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
}
bool llama_kv_self_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1);
}
// deprecated
void llama_kv_cache_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_self_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
}
// deprecated
void llama_kv_cache_seq_keep(
llama_context * ctx,
llama_seq_id seq_id) {
return llama_kv_self_seq_keep(ctx, seq_id);
}
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id);
}
// deprecated
void llama_kv_cache_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
}
void llama_kv_self_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta);
}
// deprecated
void llama_kv_cache_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
}
void llama_kv_self_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d);
}
// deprecated
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_self_seq_pos_max(ctx, seq_id);
}
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id);
}
// deprecated
void llama_kv_cache_defrag(llama_context * ctx) {
return llama_kv_self_defrag(ctx);
}
void llama_kv_self_defrag(llama_context * ctx) {
return llama_kv_cache_defrag(&ctx->kv_self);
}
// deprecated
bool llama_kv_cache_can_shift(const llama_context * ctx) {
return llama_kv_self_can_shift(ctx);
}
bool llama_kv_self_can_shift(const llama_context * ctx) {
return llama_kv_cache_can_shift(&ctx->kv_self);
}
// deprecated
void llama_kv_cache_update(llama_context * ctx) {
llama_kv_self_update(ctx);
}
void llama_kv_self_update(llama_context * ctx) {
const bool need_reserve = ctx->kv_self_update();
// reserve a worst case graph again
if (need_reserve) {
// TODO: extract to a function
const auto & cparams = ctx->cparams;
const auto & model = ctx->model;
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(ctx->sched.get());
if (!ggml_backend_sched_reserve(ctx->sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
}
}
///
int32_t llama_encode(