diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d696090cc..de54321df 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -32,6 +32,132 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
+struct llama_batch_manager : public llama_batch_manager_i {
+    llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
+        const auto & hparams = lctx.model.hparams;
+        const auto & n_embd  = hparams.n_embd;
+
+        const auto & kv_self = lctx.kv_self;
+
+        lctx.sbatch.from_batch(batch, n_embd,
+                /* simple_split */ !kv_self.recurrent,
+                /* logits_all   */ logits_all);
+    }
+
+    ~llama_batch_manager() override {
+    }
+
+    virtual llama_ubatch next() override {
+        ubatch = llama_ubatch();
+
+        const auto & cparams = lctx.cparams;
+        const auto & kv_self = lctx.kv_self;
+
+        const auto & n_ubatch = cparams.n_ubatch;
+
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        if (kv_self.recurrent) {
+            if (embd_pooled) {
+                // Pooled embeddings cannot be split across ubatches (yet)
+                ubatch = lctx.sbatch.split_seq(n_ubatch);
+            } else {
+                // recurrent model architectures are easier to implement
+                // with equal-length sequences
+                ubatch = lctx.sbatch.split_equal(n_ubatch);
+            }
+        } else {
+            ubatch = lctx.sbatch.split_simple(n_ubatch);
+        }
+
+        return ubatch;
+    }
+
+    virtual bool prepare() override {
+        const auto & cparams = lctx.cparams;
+        const auto & hparams = lctx.model.hparams;
+
+        auto & kv_self = lctx.kv_self;
+
+        // non-causal masks do not use the KV cache
+        if (hparams.causal_attn) {
+            llama_kv_self_update(&lctx);
+
+            // if we have enough unused cells before the current head ->
+            //   better to start searching from the beginning of the cache, hoping to fill it
+            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+                kv_self.head = 0;
+            }
+
+            const auto slot_info = kv_self.find_slot(ubatch);
+            if (!slot_info) {
+                return false;
+            }
+
+            kv_slot_restorer.save(slot_info);
+
+            if (!kv_self.recurrent) {
+                // a heuristic, to avoid attending the full cache if it is not yet utilized
+                // after enough generations, the benefit from this heuristic disappears
+                // if we start defragmenting the cache, the benefit from this will be more important
+                const uint32_t pad = kv_self.get_padding(cparams);
+                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
+                //kv_self.n = llama_kv_cache_cell_max(kv_self);
+            }
+        }
+
+        return true;
+    }
+
+    virtual void restore() override {
+        kv_slot_restorer.restore(lctx.kv_self);
+    }
+
+    virtual void update() override {
+        auto & kv_self = lctx.kv_self;
+
+        // update the kv ring buffer
+        {
+            kv_self.head += ubatch.n_tokens;
+
+            // Ensure kv cache head points to a valid index.
+            if (kv_self.head >= kv_self.size) {
+                kv_self.head = 0;
+            }
+        }
+    }
+
+    virtual void finalize() override {
+        const auto & cparams = lctx.cparams;
+
+        auto & kv_self = lctx.kv_self;
+
+        // decide if we need to defrag the kv cache
+        if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
+            const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
+
+            // queue defragmentation for next llama_kv_cache_update
+            if (fragmentation > cparams.defrag_thold) {
+                //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+
+                kv_self.defrag();
+            }
+        }
+    }
+
+    llama_context & lctx;
+
+    const llama_batch & batch;
+
+    llama_ubatch ubatch;
+
+    llama_kv_slot_restorer kv_slot_restorer;
+};
+
+std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
+    return std::make_unique<llama_batch_manager>(*this, batch, logits_all);
+}
+
 enum ggml_status llama_context::compute_graph(
             ggml_cgraph * graph,
                    bool   batched) {
@@ -59,7 +185,6 @@ enum ggml_status llama_context::compute_graph(
     return status;
 }
 
-
 llama_pos llama_context::pos_max() const {
     return kv_self.pos_max();
 }
@@ -94,9 +219,6 @@ void llama_context::prepare_k_shift() {
 void llama_context::prepare_defrag() {
 }
 
-void llama_context::prepare_decode(const llama_ubatch & /*ubatch*/) {
-}
-
 // llama input
 
 void llama_context::set_inputs(const llama_ubatch & ubatch) {
diff --git a/src/llama-context.h b/src/llama-context.h
index eb9a17391..47233f4f5 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -16,6 +16,20 @@
 
 using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
 
+// TODO: this is very WIP - improve
+struct llama_batch_manager_i {
+    virtual ~llama_batch_manager_i() = default;
+
+    //bool is_done() const;
+
+    virtual llama_ubatch next() = 0;
+
+    virtual bool prepare() = 0;
+    virtual void restore() = 0;
+    virtual void update() = 0;
+    virtual void finalize() = 0;
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
@@ -80,6 +94,9 @@ struct llama_context {
     ggml_abort_callback abort_callback      = nullptr;
     void *              abort_callback_data = nullptr;
 
+    // TODO: do not pass logits_all explicitly
+    std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch, bool logits_all);
+
     // returns the result of ggml_backend_sched_graph_compute_async execution
     enum ggml_status compute_graph(
                 ggml_cgraph * graph,
@@ -95,7 +112,6 @@ struct llama_context {
 
     void prepare_k_shift();
     void prepare_defrag();
-    void prepare_decode(const llama_ubatch & ubatch);
 
     void set_inputs(const llama_ubatch & ubatch);
 
diff --git a/src/llama.cpp b/src/llama.cpp
index ab0d9dda6..89de9da8e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7807,8 +7807,6 @@ static int llama_decode_impl(
     uint32_t n_outputs = 0;
     uint32_t n_outputs_prev = 0;
 
-    const auto n_ubatch = cparams.n_ubatch;
-
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
@@ -7832,27 +7830,19 @@ static int llama_decode_impl(
         return -2;
     };
 
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+    const bool logits_all = n_outputs == n_tokens_all;
 
-    lctx.sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !kv_self.recurrent,
-        /* logits_all   */ n_outputs == n_tokens_all);
+    //auto & kv_self = lctx.kv_self;
+    //llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    //lctx.sbatch.from_batch(batch, n_embd,
+    //    /* simple_split */ !kv_self.recurrent,
+    //    /* logits_all   */ logits_all);
+
+    auto batch_manager = lctx.prepare_batch(batch, logits_all);
 
     while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
+        llama_ubatch ubatch = batch_manager->next();
 
         const uint32_t n_tokens = ubatch.n_tokens;
 
@@ -7873,32 +7863,10 @@ static int llama_decode_impl(
             lctx.n_outputs = n_outputs_new;
         }
 
-        lctx.prepare_decode(ubatch);
-
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_self_update(&lctx);
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
-
-            const auto slot_info = kv_self.find_slot(ubatch);
-            if (!slot_info) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot_info);
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = kv_self.get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
+        if (!batch_manager->prepare()) {
+            LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
+            batch_manager->restore();
+            return -3;
         }
 
         // reserve a worst case graph if needed
@@ -7963,7 +7931,7 @@ static int llama_decode_impl(
 
         const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            kv_slot_restorer.restore(kv_self);
+            batch_manager->restore();
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -7975,15 +7943,7 @@ static int llama_decode_impl(
             }
         }
 
-        // update the kv ring buffer
-        {
-            kv_self.head += n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self.head >= kv_self.size) {
-                kv_self.head = 0;
-            }
-        }
+        batch_manager->update();
 
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
@@ -8061,6 +8021,7 @@ static int llama_decode_impl(
                 }
             }
         }
+
         n_outputs_prev += lctx.n_outputs;
     }
 
@@ -8089,17 +8050,7 @@ static int llama_decode_impl(
     // wait for the computation to finish (automatically done when obtaining the model output)
     //llama_synchronize(&lctx);
 
-    // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
-        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
-
-        // queue defragmentation for next llama_kv_cache_update
-        if (fragmentation > cparams.defrag_thold) {
-            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
-
-            kv_self.defrag();
-        }
-    }
+    batch_manager->finalize();
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
@@ -8178,7 +8129,7 @@ static int llama_encode_impl(
     lctx.inp_embd_enc = NULL;
    lctx.n_outputs = n_tokens;
 
-    lctx.prepare_decode(ubatch);
+    //batch_manager->prepare(ubatch);
 
     // reserve a worst case graph if needed
     // TODO: extract to a function
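
Reviewer note: the snippet below is not part of the diff. It is a minimal, standalone sketch that illustrates the call order `llama_decode_impl` now delegates to `llama_batch_manager_i` (next → prepare → compute → update, restore on failure, finalize once per decode call). Only the five method names come from the diff; every other identifier (`batch_manager_i`, `toy_batch_manager`, the token counts) is invented for illustration, the loop is condensed, and `next()` is simplified to return a token count instead of a `llama_ubatch`.

```cpp
#include <cstdio>
#include <memory>

// Mirrors the shape of llama_batch_manager_i from the diff; next() is simplified
// to return a token count (0 when the batch is exhausted) instead of a llama_ubatch.
struct batch_manager_i {
    virtual ~batch_manager_i() = default;

    virtual int  next()     = 0;
    virtual bool prepare()  = 0;
    virtual void restore()  = 0;
    virtual void update()   = 0;
    virtual void finalize() = 0;
};

// Toy stand-in for llama_batch_manager: splits a fake batch of n_tokens into
// ubatches of at most 8 tokens and logs what the real implementation would do.
struct toy_batch_manager : batch_manager_i {
    int remaining;

    explicit toy_batch_manager(int n_tokens) : remaining(n_tokens) {}

    int next() override {
        const int n = remaining < 8 ? remaining : 8;
        remaining -= n;
        return n;
    }
    bool prepare()  override { return true; }                                  // pretend a KV slot was found
    void restore()  override { std::puts("restore : roll back KV cache"); }
    void update()   override { std::puts("update  : advance KV head"); }
    void finalize() override { std::puts("finalize: check defrag threshold"); }
};

int main() {
    std::unique_ptr<batch_manager_i> bm = std::make_unique<toy_batch_manager>(20);

    // Same call order as the refactored loop in llama_decode_impl:
    for (int n = bm->next(); n > 0; n = bm->next()) {
        if (!bm->prepare()) {   // find_slot / kv_self.n clamping in the real code
            bm->restore();
            return -3;
        }
        std::printf("decode  : compute graph for ubatch of %d tokens\n", n);
        bm->update();
    }

    bm->finalize();
    return 0;
}
```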