context : introduce llama_batch_manager

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-17 20:30:16 +02:00
parent 2916a3bbbc
commit c6fa715709
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 162 additions and 73 deletions

View File

@ -32,6 +32,132 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
return relative_bucket;
}
struct llama_batch_manager : public llama_batch_manager_i {
llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
const auto & hparams = lctx.model.hparams;
const auto & n_embd = hparams.n_embd;
const auto & kv_self = lctx.kv_self;
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ logits_all);
}
~llama_batch_manager() override {
}
virtual llama_ubatch next() override {
ubatch = llama_ubatch();
const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
const auto & n_ubatch = cparams.n_ubatch;
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(n_ubatch);
}
} else {
ubatch = lctx.sbatch.split_simple(n_ubatch);
}
return ubatch;
}
virtual bool prepare() override {
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
auto & kv_self = lctx.kv_self;
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_self_update(&lctx);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}
const auto slot_info = kv_self.find_slot(ubatch);
if (!slot_info) {
return false;
}
kv_slot_restorer.save(slot_info);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self.get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
return true;
}
virtual void restore() override {
kv_slot_restorer.restore(lctx.kv_self);
}
virtual void update() override {
auto & kv_self = lctx.kv_self;
// update the kv ring buffer
{
kv_self.head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}
}
virtual void finalize() override {
const auto & cparams = lctx.cparams;
auto & kv_self = lctx.kv_self;
// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
kv_self.defrag();
}
}
}
llama_context & lctx;
const llama_batch & batch;
llama_ubatch ubatch;
llama_kv_slot_restorer kv_slot_restorer;
};
// Build a llama_batch_manager bound to this context for the given batch.
// The manager is returned through the llama_batch_manager_i interface and is owned by the caller.
std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
    std::unique_ptr<llama_batch_manager_i> manager =
        std::make_unique<llama_batch_manager>(*this, batch, logits_all);

    return manager;
}
enum ggml_status llama_context::compute_graph(
ggml_cgraph * graph,
bool batched) {
@ -59,7 +185,6 @@ enum ggml_status llama_context::compute_graph(
return status;
}
// Forward the query to the KV cache and return whatever kv_self.pos_max() reports.
llama_pos llama_context::pos_max() const {
    const llama_pos result = kv_self.pos_max();

    return result;
}
@ -94,9 +219,6 @@ void llama_context::prepare_k_shift() {
// Currently a no-op: the body is empty.
void llama_context::prepare_defrag() {
}
// Currently a no-op: the body is empty and the ubatch argument is unused
// (its name is commented out in the parameter list).
void llama_context::prepare_decode(const llama_ubatch & /*ubatch*/) {
}
// llama input
void llama_context::set_inputs(const llama_ubatch & ubatch) {

View File

@ -16,6 +16,20 @@
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
// TODO: this is very WIP - improve
// Interface for an object that walks a batch one ubatch at a time.
// llama_context::prepare_batch() returns an instance of it through a std::unique_ptr.
struct llama_batch_manager_i {
// virtual, so that an implementation can be destroyed through a pointer to this interface
virtual ~llama_batch_manager_i() = default;
//bool is_done() const;
// returns the next ubatch to process
virtual llama_ubatch next() = 0;
// gets ready to process the current ubatch; returns false on failure
// (the llama_batch_manager implementation fails when no KV cache slot can be found)
virtual bool prepare() = 0;
// meant for the failure paths: llama_decode_impl calls it when prepare() or the graph compute fails
// (the llama_batch_manager implementation forwards to its llama_kv_slot_restorer)
virtual void restore() = 0;
// llama_decode_impl calls this after the graph for a ubatch has been computed successfully
virtual void update() = 0;
// llama_decode_impl calls this after the ubatch loop has finished
virtual void finalize() = 0;
};
struct llama_context {
llama_context(const llama_model & model)
: model(model)
@ -80,6 +94,9 @@ struct llama_context {
ggml_abort_callback abort_callback = nullptr;
void * abort_callback_data = nullptr;
// TODO: do not pass logits_all explicitly
std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch, bool logits_all);
// returns the result of ggml_backend_sched_graph_compute_async execution
enum ggml_status compute_graph(
ggml_cgraph * graph,
@ -95,7 +112,6 @@ struct llama_context {
void prepare_k_shift();
void prepare_defrag();
void prepare_decode(const llama_ubatch & ubatch);
void set_inputs(const llama_ubatch & ubatch);

View File

@ -7807,8 +7807,6 @@ static int llama_decode_impl(
uint32_t n_outputs = 0;
uint32_t n_outputs_prev = 0;
const auto n_ubatch = cparams.n_ubatch;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
@ -7832,27 +7830,19 @@ static int llama_decode_impl(
return -2;
};
auto & kv_self = lctx.kv_self;
llama_kv_slot_restorer kv_slot_restorer(kv_self);
const bool logits_all = n_outputs == n_tokens_all;
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all);
//auto & kv_self = lctx.kv_self;
//llama_kv_slot_restorer kv_slot_restorer(kv_self);
//lctx.sbatch.from_batch(batch, n_embd,
// /* simple_split */ !kv_self.recurrent,
// /* logits_all */ logits_all);
auto batch_manager = lctx.prepare_batch(batch, logits_all);
while (lctx.sbatch.n_tokens > 0) {
llama_ubatch ubatch;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(n_ubatch);
}
} else {
ubatch = lctx.sbatch.split_simple(n_ubatch);
}
llama_ubatch ubatch = batch_manager->next();
const uint32_t n_tokens = ubatch.n_tokens;
@ -7873,32 +7863,10 @@ static int llama_decode_impl(
lctx.n_outputs = n_outputs_new;
}
lctx.prepare_decode(ubatch);
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_self_update(&lctx);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
kv_self.head = 0;
}
const auto slot_info = kv_self.find_slot(ubatch);
if (!slot_info) {
return 1;
}
kv_slot_restorer.save(slot_info);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self.get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
if (!batch_manager->prepare()) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
batch_manager->restore();
return -3;
}
// reserve a worst case graph if needed
@ -7963,7 +7931,7 @@ static int llama_decode_impl(
const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
if (compute_status != GGML_STATUS_SUCCESS) {
kv_slot_restorer.restore(kv_self);
batch_manager->restore();
switch (compute_status) {
case GGML_STATUS_ABORTED:
return 2;
@ -7975,15 +7943,7 @@ static int llama_decode_impl(
}
}
// update the kv ring buffer
{
kv_self.head += n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}
batch_manager->update();
// plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
@ -8061,6 +8021,7 @@ static int llama_decode_impl(
}
}
}
n_outputs_prev += lctx.n_outputs;
}
@ -8089,17 +8050,7 @@ static int llama_decode_impl(
// wait for the computation to finish (automatically done when obtaining the model output)
//llama_synchronize(&lctx);
// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
kv_self.defrag();
}
}
batch_manager->finalize();
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
@ -8178,7 +8129,7 @@ static int llama_encode_impl(
lctx.inp_embd_enc = NULL;
lctx.n_outputs = n_tokens;
lctx.prepare_decode(ubatch);
//batch_manager->prepare(ubatch);
// reserve a worst case graph if needed
// TODO: extract to a function