cont : move kv_self update to llama_context

ggml-ci
Georgi Gerganov 2025-01-16 21:55:12 +02:00
parent 556155a525
commit 12719550a6
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 157 additions and 154 deletions
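For orientation, a minimal user-side sketch of the flow this commit touches, assuming the public llama.h entry points of this period (the helper name shrink_cache and the concrete positions are illustrative, not part of the commit):

#include "llama.h"

// schedule a deferred K-shift and defrag, then apply them; after this commit the
// actual work is done by llama_context::kv_self_update(), while the free-function
// wrapper re-reserves the worst-case graph when needed
static void shrink_cache(struct llama_context * ctx, llama_pos n_past, llama_pos n_discard) {
    // drop the oldest n_discard positions of sequence 0 ...
    llama_kv_cache_seq_rm (ctx, 0, 0, n_discard);

    // ... and shift the remaining positions back, which sets kv_self.has_shift
    llama_kv_cache_seq_add(ctx, 0, n_discard, n_past, -n_discard);

    // optionally schedule defragmentation as well (sets kv_self.do_defrag)
    llama_kv_cache_defrag(ctx);

    // applies the deferred K-shift/defrag via the code paths changed below
    llama_kv_cache_update(ctx);
}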


@@ -32,6 +32,38 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
    return relative_bucket;
}

enum ggml_status llama_context::compute_graph(
        ggml_cgraph * graph,
        bool batched) {
    int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
    ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;

    if (backend_cpu != nullptr) {
        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
        set_threadpool_fn(backend_cpu, tp);
    }

    // set the number of threads for all the backends
    for (const auto & set_n_threads_fn : set_n_threads_fns) {
        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
    }

    auto status = ggml_backend_sched_graph_compute_async(sched.get(), graph);
    if (status != GGML_STATUS_SUCCESS) {
        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
    }

    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));

    return status;
}
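A minimal call-site sketch, assuming an existing llama_context lctx, an already allocated graph gf, and the batch size n_tokens of the current ubatch:

// before this commit: llama_graph_compute(lctx, gf, n_threads, threadpool);
// the context now picks the thread count and threadpool from the batched flag
const auto compute_status = lctx.compute_graph(gf, /*batched=*/ n_tokens > 1);
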
llama_pos llama_context::pos_max() const {
    return kv_self.pos_max();
}

// TODO: improve
void llama_context::reset() {
    inp_tokens = nullptr;
@@ -540,6 +572,93 @@ ggml_tensor * llama_context::build_lora_mm_id(
    return res;
}

bool llama_context::kv_self_update() {
    bool need_reserve = false;

    auto & kv = kv_self;

    if (kv.has_shift) {
        if (!kv.can_shift) {
            GGML_ABORT("The current context does not support K-shift");
        }

        // apply K-shift if needed
        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
            prepare_k_shift();

            ggml_backend_sched_reset(sched.get());

            struct ggml_init_params params = {
                /*.mem_size =*/ buf_compute_meta.size(),
                /*.mem_buffer =*/ buf_compute_meta.data(),
                /*.no_alloc =*/ true,
            };

            ggml_context * ctx0 = ggml_init(params);

            reset();

            ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);

            build_k_shift(ctx0, gf);

            ggml_backend_sched_alloc_graph(sched.get(), gf);

            set_inputs({});

            compute_graph(gf, false);

            ggml_free(ctx0);

            need_reserve = true;
        }

        {
            kv.has_shift = false;

            for (uint32_t i = 0; i < kv.size; ++i) {
                kv.cells[i].delta = 0;
            }
        }
    }

    // defragment the KV cache if needed
    if (kv.do_defrag) {
        prepare_defrag();

        ggml_backend_sched_reset(sched.get());

        struct ggml_init_params params = {
            /*.mem_size =*/ buf_compute_meta.size(),
            /*.mem_buffer =*/ buf_compute_meta.data(),
            /*.no_alloc =*/ true,
        };

        ggml_context * ctx0 = ggml_init(params);

        reset();

        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);

        build_defrag(ctx0, gf);

        ggml_backend_sched_alloc_graph(sched.get(), gf);

        // no input
        //set_inputs({});

        compute_graph(gf, false);

        ggml_free(ctx0);

        need_reserve = true;

        kv.do_defrag = false;
    }

    return need_reserve;
}

void llama_context::build_attn_inp(
        ggml_context * ctx0,
        int32_t n_tokens,


@@ -79,6 +79,13 @@ struct llama_context {
    ggml_abort_callback abort_callback = nullptr;
    void * abort_callback_data = nullptr;

    // returns the result of ggml_backend_sched_graph_compute_async execution
    enum ggml_status compute_graph(
            ggml_cgraph * graph,
            bool batched);

    llama_pos pos_max() const;

    void reset();

    void prepare_k_shift();
@@ -129,6 +136,9 @@ struct llama_context {
    struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
    struct ggml_tensor * inp_K_shift; // I32 [kv_size]

    // return true if need to reserve new worst-case graph
    bool kv_self_update();

    void build_attn_inp(
            ggml_context * ctx0,
            int32_t n_tokens,


@@ -110,7 +110,6 @@ struct llm_build_context {
    const llama_hparams & hparams;
    const llama_cparams & cparams;
    const llama_ubatch & ubatch;
    //const llama_kv_cache & kv_self;

    const llama_adapter_cvec & cvec;
    const llama_loras & loras;
@@ -137,8 +136,6 @@ struct llm_build_context {
    const float norm_rms_eps;

    const int32_t n_tokens;
    //const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
    //const int32_t kv_head; // index of where we store new KV data in the cache
    const int32_t n_outputs;
    const int32_t n_outputs_enc;
    const int32_t n_ctx_orig;
@@ -166,7 +163,6 @@ struct llm_build_context {
        hparams (model.hparams),
        cparams (lctx.cparams),
        ubatch (ubatch),
        //kv_self (lctx.kv_self),
        cvec (lctx.cvec),
        loras (lctx.loras),
        n_embd (hparams.n_embd),
@@ -190,8 +186,6 @@ struct llm_build_context {
        norm_eps (hparams.f_norm_eps),
        norm_rms_eps (hparams.f_norm_rms_eps),
        n_tokens (ubatch.n_tokens),
        //n_kv (worst_case ? kv_self.size : kv_self.n),
        //kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
        n_outputs (worst_case ? n_tokens : lctx.n_outputs),
        n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
        n_ctx_orig (cparams.n_ctx_orig_yarn),
@@ -7532,40 +7526,6 @@
    }
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx) {
    llama_ubatch dummy = {};
    dummy.equal_seqs = true;

    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };

    struct llm_build_context llm(lctx, dummy, cb, false);

    llm.init();

    struct ggml_cgraph * result = llm.build_defrag();

    llm.free();

    return result;
}

static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
    llama_ubatch dummy = {};
    dummy.equal_seqs = true;

    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };

    struct llm_build_context llm(lctx, dummy, cb, false);

    llm.init();

    struct ggml_cgraph * result = llm.build_k_shift();

    llm.free();

    return result;
}

static struct ggml_cgraph * llama_build_graph(
        llama_context & lctx,
        const llama_ubatch & ubatch,
@@ -7836,33 +7796,6 @@ static struct ggml_cgraph * llama_build_graph(
    return result;
}
// returns the result of ggml_backend_sched_graph_compute_async execution
static enum ggml_status llama_graph_compute(
        llama_context & lctx,
        ggml_cgraph * gf,
        int n_threads,
        ggml_threadpool * threadpool) {
    if (lctx.backend_cpu != nullptr) {
        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(lctx.backend_cpu));
        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
        set_threadpool_fn(lctx.backend_cpu, threadpool);
    }

    // set the number of threads for all the backends
    for (const auto & set_n_threads_fn : lctx.set_n_threads_fns) {
        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
    }

    auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
    if (status != GGML_STATUS_SUCCESS) {
        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
    }

    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));

    return status;
}
// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
@@ -7887,7 +7820,7 @@ static int llama_decode_impl(
    }

    // temporary allocate memory for the input batch if needed
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.pos_max() + 1);
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);

    const llama_batch & batch = batch_allocr.batch;
    const uint32_t n_tokens_all = batch.n_tokens;
@@ -7989,16 +7922,11 @@ static int llama_decode_impl(
            lctx.n_outputs = n_outputs_new;
        }

        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

        GGML_ASSERT(n_threads > 0);

        lctx.prepare_decode(ubatch);

        // non-causal masks do not use the KV cache
        if (hparams.causal_attn) {
            llama_kv_self_update(&lctx); // TODO: lctx->kv_self_update()
            llama_kv_self_update(&lctx);

            // if we have enough unused cells before the current head ->
            // better to start searching from the beginning of the cache, hoping to fill it
@@ -8058,7 +7986,7 @@ static int llama_decode_impl(
            GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
        }

        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
        const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
        if (compute_status != GGML_STATUS_SUCCESS) {
            kv_slot_restorer.restore(kv_self);
            switch (compute_status) {
@@ -8226,7 +8154,7 @@ static int llama_encode_impl(
    }

    // temporary allocate memory for the input batch if needed
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.pos_max() + 1);
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);

    const llama_batch & batch = batch_allocr.batch;
    const uint32_t n_tokens = batch.n_tokens;
@@ -8274,11 +8202,6 @@ static int llama_encode_impl(
    lctx.inp_embd_enc = NULL;
    lctx.n_outputs = n_tokens;

    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
    ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

    GGML_ASSERT(n_threads > 0);

    lctx.prepare_decode(ubatch);

    ggml_backend_sched_reset(lctx.sched.get());
@@ -8310,7 +8233,7 @@ static int llama_encode_impl(
        }
    }

    const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
    const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
    switch (compute_status) {
        case GGML_STATUS_SUCCESS:
            break;
@@ -8397,76 +8320,6 @@ static int llama_encode_impl(

    return 0;
}
// TODO: move to llama_context
static void llama_kv_self_update_impl(llama_context & lctx) {
    bool need_reserve = false;

    auto & kv = lctx.kv_self;

    if (kv.has_shift) {
        if (!kv.can_shift) {
            GGML_ABORT("The current context does not support K-shift");
        }

        // apply K-shift if needed
        if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
            lctx.prepare_k_shift();

            ggml_backend_sched_reset(lctx.sched.get());

            ggml_cgraph * gf = llama_build_graph_k_shift(lctx);

            ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);

            lctx.set_inputs({});

            llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);

            need_reserve = true;
        }

        {
            kv.has_shift = false;

            for (uint32_t i = 0; i < kv.size; ++i) {
                kv.cells[i].delta = 0;
            }
        }
    }

    // defragment the KV cache if needed
    if (kv.do_defrag) {
        lctx.prepare_defrag();

        ggml_backend_sched_reset(lctx.sched.get());

        ggml_cgraph * gf = llama_build_graph_defrag(lctx);

        llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);

        need_reserve = true;

        kv.do_defrag = false;
    }

    // reserve a worst case graph again
    if (need_reserve) {
        // TODO: extract to a function
        // build worst-case graph
        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
        uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
        llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);

        // initialize scheduler with the worst-case graph
        ggml_backend_sched_reset(lctx.sched.get());
        if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        }
    }
}
int32_t llama_set_adapter_lora(
        struct llama_context * ctx,
        struct llama_adapter_lora * adapter,
@@ -9214,9 +9067,30 @@ void llama_kv_cache_update(llama_context * ctx) {
    llama_kv_self_update(ctx);
}

// TODO: move to llama-context
void llama_kv_self_update(llama_context * ctx) {
    const bool need_reserve = ctx->kv_self_update();

    // reserve a worst case graph again
    if (need_reserve) {
        // TODO: extract to a function

        const auto & cparams = ctx->cparams;
        const auto & model = ctx->model;

        // build worst-case graph
        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

        ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);

        // initialize scheduler with the worst-case graph
        ggml_backend_sched_reset(ctx->sched.get());
        if (!ggml_backend_sched_reserve(ctx->sched.get(), gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        }
    }
}
///
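One possible shape for the "TODO: extract to a function" above, sketched as a hypothetical static helper (the name llama_reserve_worst_case_graph is not part of this commit; the body only reuses code from the hunk above):

// hypothetical helper that factors out the worst-case reserve done after a KV-cache update
static void llama_reserve_worst_case_graph(llama_context & lctx) {
    const auto & cparams = lctx.cparams;
    const auto & model = lctx.model;

    // build worst-case graph
    uint32_t n_seqs = 1; // TODO: worst-case number of sequences
    uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

    llama_token token = model.vocab.token_bos(); // dummy token, selects the token-input graph
    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

    ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);

    // initialize scheduler with the worst-case graph
    ggml_backend_sched_reset(lctx.sched.get());
    if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
    }
}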