From fb4a0ec0833c71cff5a1a367ba375447ce6106eb Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Wed, 13 Nov 2024 20:00:35 +0200 Subject: [PATCH] llama : propagate the results of `graph_compute` (#9525) * llama: propagating the results of `graph_compute` to the user interface * llama: reverting kv_cache in case of failed compute * llama: `llama_kv_cache_state` was removed, only the result of `llama_graph_compute` is returned * llama: restore a kv_cache in case of failed computation * llama: correct reverting of the entire batch. also updates `llama_kv_cache_find_slot`, will correctly count the number of `used` cells for recurrent models * llama: updated comments * llama : add comments about KV cache state after error --------- Co-authored-by: Georgi Gerganov --- include/llama.h | 4 +- src/llama.cpp | 120 ++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 109 insertions(+), 15 deletions(-) diff --git a/include/llama.h b/include/llama.h index ccb48f73c..5e742642e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -797,7 +797,7 @@ extern "C" { // Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Stores the encoder output internally for later use by the decoder cross-attention layers. // 0 - success - // < 0 - error + // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch); @@ -805,7 +805,7 @@ extern "C" { // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) - // < 0 - error + // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch); diff --git a/src/llama.cpp b/src/llama.cpp index 4d89c5222..97eee26a5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3502,11 +3502,24 @@ static bool llama_kv_cache_init( return true; } +// a structure holds information about the slot found in llama_kv_cache_find_slot +struct llama_kv_cache_slot_info { + std::pair boundaries; // slot boundaries [begin, end) + bool found = false; // the slot was found + + explicit llama_kv_cache_slot_info(bool found_) : found{found_} {} + llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {} + + operator bool() const { return found; } +}; +static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; + // find an empty slot of size "n_tokens" in the cache // updates the cache head +// returns a structure holding information about the slot found // Note: On success, it's important that cache.head points // to the first cell of the slot. -static bool llama_kv_cache_find_slot( +static struct llama_kv_cache_slot_info llama_kv_cache_find_slot( struct llama_kv_cache & cache, const struct llama_ubatch & batch) { const uint32_t n_tokens = batch.n_tokens; @@ -3534,7 +3547,7 @@ static bool llama_kv_cache_find_slot( // too big seq_id // TODO: would it be possible to resize the cache instead? LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); - return false; + return llama_kv_cache_slot_info_failed; } if (j > 0) { llama_kv_cell & seq = cache.cells[seq_id]; @@ -3669,15 +3682,17 @@ static bool llama_kv_cache_find_slot( // allow getting the range of used cells, from head to head + n cache.head = min; cache.n = max - min + 1; + cache.used = std::count_if(cache.cells.begin(), cache.cells.end(), + [](const llama_kv_cell& cell){ return !cell.is_empty(); }); // sanity check - return cache.n >= n_seqs; + return llama_kv_cache_slot_info(cache.n >= n_seqs); } // otherwise, one cell per token. if (n_tokens > cache.size) { LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); - return false; + return llama_kv_cache_slot_info_failed; } uint32_t n_tested = 0; @@ -3705,7 +3720,7 @@ static bool llama_kv_cache_find_slot( if (n_tested >= cache.size) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; + return llama_kv_cache_slot_info_failed; } } @@ -3722,7 +3737,7 @@ static bool llama_kv_cache_find_slot( cache.used += n_tokens; - return true; + return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens); } // find how many cells are currently in use @@ -3998,6 +4013,53 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) return cparams.flash_attn ? 256u : 32u; } +// saves the kv_cache state for future recovery. +// used to rollback llama_kv_cache_find_slot changes. +struct llama_kv_slot_restorer { + struct llama_kv_cache_state { + uint32_t head = 0; + uint32_t n = 0; + } old_state; + + // for non-recurrent models only + // list of slots to restore + std::vector> slot_boundaries; + + bool do_restore = false; + + explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) { + old_state.head = cache.head; + old_state.n = cache.n; + } + + // saves a slot information for future restoration + void save(const struct llama_kv_cache_slot_info & slot) { + if (slot) { + do_restore = true; + if (slot.boundaries.first != slot.boundaries.second) { + slot_boundaries.push_back(slot.boundaries); + } + } + } + + // must be explicitly called to restore the kv_cache state + // and rollback changes from all llama_kv_cache_find_slot calls + void restore(struct llama_kv_cache & cache) { + if (do_restore) { + cache.head = old_state.head; + cache.n = old_state.n; + + if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased + llama_kv_cache_seq_rm(cache, -1, -1, -1); + } else { + for (auto & slot : slot_boundaries) { + llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second); + } + } + } + } +}; + // // model loading and saving // @@ -17181,7 +17243,8 @@ static void llama_output_reorder(struct llama_context * ctx) { } } -static void llama_graph_compute( +// returns the result of ggml_backend_sched_graph_compute_async execution +static enum ggml_status llama_graph_compute( llama_context & lctx, ggml_cgraph * gf, int n_threads, @@ -17196,15 +17259,20 @@ static void llama_graph_compute( set_n_threads_fn.second(set_n_threads_fn.first, n_threads); } - auto err = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf); - if (err != GGML_STATUS_SUCCESS) { - LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err); + auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); } // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched)); + + return status; } // decode a batch of tokens by evaluating the transformer +// in case of unsuccessful decoding (error or warning), +// the kv_cache state will be returned to its original state +// (for non-recurrent models) or cleaned (for recurrent models) // // - lctx: llama context // - batch: batch to evaluate @@ -17254,6 +17322,7 @@ static int llama_decode_internal( lctx.n_queued_tokens += n_tokens_all; auto & kv_self = lctx.kv_self; + llama_kv_slot_restorer kv_slot_restorer(kv_self); const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -17338,9 +17407,11 @@ static int llama_decode_internal( kv_self.head = 0; } - if (!llama_kv_cache_find_slot(kv_self, ubatch)) { + const auto slot = llama_kv_cache_find_slot(kv_self, ubatch); + if (!slot) { return 1; } + kv_slot_restorer.save(slot); if (!kv_self.recurrent) { // a heuristic, to avoid attending the full cache if it is not yet utilized @@ -17387,7 +17458,19 @@ static int llama_decode_internal( llama_set_inputs(lctx, ubatch); - llama_graph_compute(lctx, gf, n_threads, threadpool); + const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); + if (compute_status != GGML_STATUS_SUCCESS) { + kv_slot_restorer.restore(kv_self); + switch (compute_status) { + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } + } // update the kv ring buffer { @@ -17624,7 +17707,18 @@ static int llama_encode_internal( llama_set_inputs(lctx, ubatch); - llama_graph_compute(lctx, gf, n_threads, threadpool); + const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); + switch (compute_status) { + case GGML_STATUS_SUCCESS: + break; + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } // extract embeddings if (embd) {