context : introduce llama_batch_manager
ggml-ci
commit c6fa715709
parent 2916a3bbbc
@@ -32,6 +32,132 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
    return relative_bucket;
}

struct llama_batch_manager : public llama_batch_manager_i {
    llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
        const auto & hparams = lctx.model.hparams;
        const auto & n_embd  = hparams.n_embd;

        const auto & kv_self = lctx.kv_self;

        lctx.sbatch.from_batch(batch, n_embd,
            /* simple_split */ !kv_self.recurrent,
            /* logits_all   */ logits_all);
    }

    ~llama_batch_manager() override {
    }

    virtual llama_ubatch next() override {
        ubatch = llama_ubatch();

        const auto & cparams = lctx.cparams;
        const auto & kv_self = lctx.kv_self;

        const auto & n_ubatch = cparams.n_ubatch;

        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

        if (kv_self.recurrent) {
            if (embd_pooled) {
                // Pooled embeddings cannot be split across ubatches (yet)
                ubatch = lctx.sbatch.split_seq(n_ubatch);
            } else {
                // recurrent model architectures are easier to implement
                // with equal-length sequences
                ubatch = lctx.sbatch.split_equal(n_ubatch);
            }
        } else {
            ubatch = lctx.sbatch.split_simple(n_ubatch);
        }

        return ubatch;
    }

    virtual bool prepare() override {
        const auto & cparams = lctx.cparams;
        const auto & hparams = lctx.model.hparams;

        auto & kv_self = lctx.kv_self;

        // non-causal masks do not use the KV cache
        if (hparams.causal_attn) {
            llama_kv_self_update(&lctx);

            // if we have enough unused cells before the current head ->
            //   better to start searching from the beginning of the cache, hoping to fill it
            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
                kv_self.head = 0;
            }

            const auto slot_info = kv_self.find_slot(ubatch);
            if (!slot_info) {
                return false;
            }

            kv_slot_restorer.save(slot_info);

            if (!kv_self.recurrent) {
                // a heuristic, to avoid attending the full cache if it is not yet utilized
                // after enough generations, the benefit from this heuristic disappears
                // if we start defragmenting the cache, the benefit from this will be more important
                const uint32_t pad = kv_self.get_padding(cparams);
                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
                //kv_self.n = llama_kv_cache_cell_max(kv_self);
            }
        }

        return true;
    }

    virtual void restore() override {
        kv_slot_restorer.restore(lctx.kv_self);
    }

    virtual void update() override {
        auto & kv_self = lctx.kv_self;

        // update the kv ring buffer
        {
            kv_self.head += ubatch.n_tokens;

            // Ensure kv cache head points to a valid index.
            if (kv_self.head >= kv_self.size) {
                kv_self.head = 0;
            }
        }
    }

    virtual void finalize() override {
        const auto & cparams = lctx.cparams;

        auto & kv_self = lctx.kv_self;

        // decide if we need to defrag the kv cache
        if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
            const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;

            // queue defragmentation for next llama_kv_cache_update
            if (fragmentation > cparams.defrag_thold) {
                //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);

                kv_self.defrag();
            }
        }
    }

    llama_context & lctx;

    const llama_batch & batch;

    llama_ubatch ubatch;

    llama_kv_slot_restorer kv_slot_restorer;
};

std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
    return std::make_unique<llama_batch_manager>(*this, batch, logits_all);
}

enum ggml_status llama_context::compute_graph(
        ggml_cgraph * graph,
        bool batched) {
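For orientation, here is a minimal sketch of how a caller is expected to drive this manager. It mirrors the refactored decode loop further below in this commit; the graph construction and error reporting are indicated only by comments and the local variable name bman is illustrative:

    // sketch only: the lifecycle implied by llama_batch_manager_i
    auto bman = lctx.prepare_batch(batch, logits_all);

    while (lctx.sbatch.n_tokens > 0) {
        llama_ubatch ubatch = bman->next();   // split off the next ubatch

        if (!bman->prepare()) {               // reserve KV cache slots, apply pending KV updates
            bman->restore();                  // roll back any slots saved so far
            break;                            // caller reports the error
        }

        // ... build and compute the graph for this ubatch ...
        // on compute failure: bman->restore(); and bail out

        bman->update();                       // advance the KV cache ring buffer head
    }

    bman->finalize();                         // e.g. queue KV cache defragmentation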
@@ -59,7 +185,6 @@ enum ggml_status llama_context::compute_graph(
    return status;
}


llama_pos llama_context::pos_max() const {
    return kv_self.pos_max();
}
@@ -94,9 +219,6 @@ void llama_context::prepare_k_shift() {
void llama_context::prepare_defrag() {
}

void llama_context::prepare_decode(const llama_ubatch & /*ubatch*/) {
}

// llama input

void llama_context::set_inputs(const llama_ubatch & ubatch) {
@@ -16,6 +16,20 @@

using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;

// TODO: this is very WIP - improve
struct llama_batch_manager_i {
    virtual ~llama_batch_manager_i() = default;

    //bool is_done() const;

    virtual llama_ubatch next() = 0;

    virtual bool prepare() = 0;
    virtual void restore() = 0;
    virtual void update() = 0;
    virtual void finalize() = 0;
};

struct llama_context {
    llama_context(const llama_model & model)
        : model(model)
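Everything except the destructor is pure virtual, so a concrete manager only has to supply the five hooks. A hypothetical no-op implementation (the name llama_batch_manager_noop is illustrative and not part of this commit) would look like:

    struct llama_batch_manager_noop : public llama_batch_manager_i {
        llama_ubatch next()     override { return llama_ubatch(); } // nothing left to split
        bool         prepare()  override { return true; }           // nothing to reserve
        void         restore()  override {}                         // nothing to roll back
        void         update()   override {}                         // no KV state to advance
        void         finalize() override {}                         // no deferred maintenance
    };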
@@ -80,6 +94,9 @@ struct llama_context {
    ggml_abort_callback abort_callback = nullptr;
    void * abort_callback_data = nullptr;

    // TODO: do not pass logits_all explicitly
    std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch, bool logits_all);

    // returns the result of ggml_backend_sched_graph_compute_async execution
    enum ggml_status compute_graph(
            ggml_cgraph * graph,
@@ -95,7 +112,6 @@ struct llama_context {

    void prepare_k_shift();
    void prepare_defrag();
    void prepare_decode(const llama_ubatch & ubatch);

    void set_inputs(const llama_ubatch & ubatch);
@@ -7807,8 +7807,6 @@ static int llama_decode_impl(
    uint32_t n_outputs      = 0;
    uint32_t n_outputs_prev = 0;

    const auto n_ubatch = cparams.n_ubatch;

    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
@@ -7832,27 +7830,19 @@ static int llama_decode_impl(
        return -2;
    };

    auto & kv_self = lctx.kv_self;
    llama_kv_slot_restorer kv_slot_restorer(kv_self);
    const bool logits_all = n_outputs == n_tokens_all;

    lctx.sbatch.from_batch(batch, n_embd,
        /* simple_split */ !kv_self.recurrent,
        /* logits_all   */ n_outputs == n_tokens_all);
    //auto & kv_self = lctx.kv_self;
    //llama_kv_slot_restorer kv_slot_restorer(kv_self);

    //lctx.sbatch.from_batch(batch, n_embd,
    //    /* simple_split */ !kv_self.recurrent,
    //    /* logits_all   */ logits_all);

    auto batch_manager = lctx.prepare_batch(batch, logits_all);

    while (lctx.sbatch.n_tokens > 0) {
        llama_ubatch ubatch;
        if (kv_self.recurrent) {
            if (embd_pooled) {
                // Pooled embeddings cannot be split across ubatches (yet)
                ubatch = lctx.sbatch.split_seq(n_ubatch);
            } else {
                // recurrent model architectures are easier to implement
                // with equal-length sequences
                ubatch = lctx.sbatch.split_equal(n_ubatch);
            }
        } else {
            ubatch = lctx.sbatch.split_simple(n_ubatch);
        }
        llama_ubatch ubatch = batch_manager->next();

        const uint32_t n_tokens = ubatch.n_tokens;
@@ -7873,32 +7863,10 @@ static int llama_decode_impl(
            lctx.n_outputs = n_outputs_new;
        }

        lctx.prepare_decode(ubatch);

        // non-causal masks do not use the KV cache
        if (hparams.causal_attn) {
            llama_kv_self_update(&lctx);

            // if we have enough unused cells before the current head ->
            //   better to start searching from the beginning of the cache, hoping to fill it
            if (kv_self.head > kv_self.used + 2*n_tokens) {
                kv_self.head = 0;
            }

            const auto slot_info = kv_self.find_slot(ubatch);
            if (!slot_info) {
                return 1;
            }
            kv_slot_restorer.save(slot_info);

            if (!kv_self.recurrent) {
                // a heuristic, to avoid attending the full cache if it is not yet utilized
                // after enough generations, the benefit from this heuristic disappears
                // if we start defragmenting the cache, the benefit from this will be more important
                const uint32_t pad = kv_self.get_padding(cparams);
                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
                //kv_self.n = llama_kv_cache_cell_max(kv_self);
            }
        if (!batch_manager->prepare()) {
            LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
            batch_manager->restore();
            return -3;
        }

        // reserve a worst case graph if needed
@@ -7963,7 +7931,7 @@ static int llama_decode_impl(

        const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
        if (compute_status != GGML_STATUS_SUCCESS) {
            kv_slot_restorer.restore(kv_self);
            batch_manager->restore();
            switch (compute_status) {
                case GGML_STATUS_ABORTED:
                    return 2;
@@ -7975,15 +7943,7 @@ static int llama_decode_impl(
            }
        }

        // update the kv ring buffer
        {
            kv_self.head += n_tokens;

            // Ensure kv cache head points to a valid index.
            if (kv_self.head >= kv_self.size) {
                kv_self.head = 0;
            }
        }
        batch_manager->update();

        // plot the computation graph in dot format (for debugging purposes)
        //if (n_past%100 == 0) {
@@ -8061,6 +8021,7 @@ static int llama_decode_impl(
                }
            }
        }

        n_outputs_prev += lctx.n_outputs;
    }
@@ -8089,17 +8050,7 @@ static int llama_decode_impl(
    // wait for the computation to finish (automatically done when obtaining the model output)
    //llama_synchronize(&lctx);

    // decide if we need to defrag the kv cache
    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;

        // queue defragmentation for next llama_kv_cache_update
        if (fragmentation > cparams.defrag_thold) {
            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);

            kv_self.defrag();
        }
    }
    batch_manager->finalize();

    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
    // overlap with device computation.
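The defragmentation trigger that moved into finalize() is unchanged: once at least 128 cells are in view, fragmentation is measured as the unused fraction of the attended window, 1 - used/n. A small worked example with illustrative numbers (not taken from the commit):

    // illustrative numbers only
    const uint32_t n_view = 1024; // kv_self.n    -- cells currently attended
    const uint32_t n_used =  768; // kv_self.used -- cells actually occupied

    const float fragmentation = n_view >= 128 ? 1.0f - float(n_used)/float(n_view) : 0.0f; // = 0.25
    // with defrag_thold = 0.1f: 0.25 > 0.1, so kv_self.defrag() would be queued
    // for the next KV cache update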
@@ -8178,7 +8129,7 @@ static int llama_encode_impl(
    lctx.inp_embd_enc = NULL;
    lctx.n_outputs = n_tokens;

    lctx.prepare_decode(ubatch);
    //batch_manager->prepare(ubatch);

    // reserve a worst case graph if needed
    // TODO: extract to a function