diff --git a/src/llama.cpp b/src/llama.cpp
index e8cfe5012..eed90d5de 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8432,13 +8432,141 @@ static enum ggml_status llama_graph_compute(
     return status;
 }
 
+static int llama_prepare_sbatch(
+        llama_context & lctx,
+        const llama_batch & batch,
+        uint32_t & n_outputs) {
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const int64_t  n_embd       = hparams.n_embd;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    if (batch.token) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+        }
+    }
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+    lctx.n_queued_tokens += n_tokens_all;
+    lctx.embd_seq.clear();
+
+    // count outputs
+    if (batch.logits && !embd_pooled) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            n_outputs += batch.logits[i] != 0;
+        }
+    } else if (lctx.logits_all || embd_pooled) {
+        n_outputs = n_tokens_all;
+    } else {
+        // keep last output only
+        n_outputs = 1;
+    }
+
+    lctx.sbatch.from_batch(batch, n_embd,
+        /* simple_split */ !lctx.kv_self.recurrent,
+        /* logits_all   */ n_outputs == n_tokens_all);
+
+    // reserve output buffer
+    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+        return -2;
+    };
+
+    return 0;
+}
+
+static int llama_prepare_ubatch(
+        llama_context & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+        llama_ubatch & ubatch,
+        const uint32_t n_outputs,
+        const uint32_t n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
+        } else {
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
+        }
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }
+
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;
+
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
+            }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
+
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
+        }
+
+        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);
+
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }
+
+    return 0;
+}
+
 // decode a batch of tokens by evaluating the transformer
 //   in case of unsuccessful decoding (error or warning),
 //   the kv_cache state will be returned to its original state
 //   (for non-recurrent models) or cleaned (for recurrent models)
 //
 //   - lctx:      llama context
-//   - batch:     batch to evaluate
+//   - inp_batch: batch to evaluate
 //
 //   return 0 on success
 //   return positive int on warning
@@ -8455,37 +8583,18 @@ static int llama_decode_impl(
         return -1;
     }
 
-    // temporary allocate memory for the input batch if needed
+    // temporarily allocate memory for the input batch if needed
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
     const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens_all = batch.n_tokens;
 
     const auto & model   = lctx.model;
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
     }
-    lctx.n_queued_tokens += n_tokens_all;
-
     auto & kv_self = lctx.kv_self;
     llama_kv_slot_restorer kv_slot_restorer(kv_self);
 
@@ -8495,99 +8604,27 @@ static int llama_decode_impl(
     uint32_t n_outputs = 0;
     uint32_t n_outputs_prev = 0;
 
-    const auto n_ubatch = cparams.n_ubatch;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-    lctx.embd_seq.clear();
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch.logits[i] != 0;
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
         }
-    } else if (lctx.logits_all || embd_pooled) {
-        n_outputs = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs = 1;
     }
 
-    lctx.sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !kv_self.recurrent,
-        /* logits_all   */ n_outputs == n_tokens_all);
-
-    // reserve output buffer
-    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
-        return -2;
-    };
-
     while (lctx.sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
-        const uint32_t n_tokens = ubatch.n_tokens;
-
-        // count the outputs in this u_batch
         {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
 
         GGML_ASSERT(n_threads > 0);
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
-
-            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            if (!slot) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot);
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
-        }
-
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
         ggml_backend_sched_reset(lctx.sched.get());
@@ -8640,7 +8677,7 @@ static int llama_decode_impl(
 
         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;
 
             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {
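For reference, the control flow that these hunks produce in llama_decode_impl can be condensed into the sketch below. It is not part of the patch: the function name decode_flow_sketch is hypothetical, the graph build/compute, output extraction, and KV-cache restore paths of the real function are elided, and only the calls introduced by this diff are shown.

// Hypothetical condensed sketch (not in the patch): how the two new helpers
// compose inside the decode loop after this refactor.
static int decode_flow_sketch(llama_context & lctx, const llama_batch & batch) {
    uint32_t n_outputs = 0;

    // per-call preparation: token validation, output counting, sbatch setup,
    // output-buffer reservation (returns a negative value on error)
    {
        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
        if (ret != 0) {
            return ret;
        }
    }

    llama_kv_slot_restorer kv_slot_restorer(lctx.kv_self);

    while (lctx.sbatch.n_tokens > 0) {
        llama_ubatch ubatch;

        // per-ubatch preparation: sbatch splitting, per-ubatch output count,
        // KV-cache slot search (returns 1 as a warning if no slot is found)
        {
            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
            if (ret != 0) {
                return ret;
            }
        }

        // ... build the graph, call llama_graph_compute(), extract logits/embeddings ...

        lctx.kv_self.head += ubatch.n_tokens; // advance the KV ring buffer
    }

    return 0;
}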