diff --git a/llama.cpp b/llama.cpp
index 6878bc893..7c6afa7d1 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13730,7 +13730,9 @@ static int llama_decode_internal(
         n_outputs = 1;
     }
 
-    lctx.sbatch.from_batch(batch_all, n_embd, /* legacy_split */ rs_self.size == 0, lctx.logits_all);
+    lctx.sbatch.from_batch(batch_all, n_embd,
+        /* legacy_split */ rs_self.size == 0,
+        /* logits_all   */ n_outputs == n_tokens_all);
 
     // reserve output buffer
     if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
@@ -13740,6 +13742,7 @@ static int llama_decode_internal(
 
     while (lctx.sbatch.n_tokens > 0) {
         // TODO: deprecate slice splits in favor of equal splits
+        // For now, only use equal splits for recurrent or hybrid model architectures
        llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
        const uint32_t n_tokens = u_batch.n_tokens;
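
For context on the second hunk, below is a minimal, self-contained C++ sketch of the split-selection idea. The `toy_sbatch` type, its `split_equal`/`split_slice` behaviour, and the `rs_size` flag are simplified stand-ins invented for illustration, not the actual `llama_sbatch` API: here the "equal" split draws tokens from every sequence so they advance together, while the legacy "slice" split just takes the next contiguous chunk, and the ternary mirrors the shape of the loop in the diff.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy stand-in for llama_sbatch: one vector of pending tokens per sequence.
struct toy_sbatch {
    std::vector<std::vector<int32_t>> seqs;

    size_t n_tokens() const {
        size_t n = 0;
        for (const auto & s : seqs) { n += s.size(); }
        return n;
    }

    // take up to n_ubatch tokens, spread evenly across all sequences
    std::vector<int32_t> split_equal(size_t n_ubatch) {
        std::vector<int32_t> out;
        const size_t per_seq = std::max<size_t>(1, n_ubatch / std::max<size_t>(1, seqs.size()));
        for (auto & s : seqs) {
            const size_t take = std::min(per_seq, s.size());
            out.insert(out.end(), s.begin(), s.begin() + take);
            s.erase(s.begin(), s.begin() + take);
        }
        return out;
    }

    // legacy behaviour in this toy: one contiguous slice from the first
    // sequence that still has tokens
    std::vector<int32_t> split_slice(size_t n_ubatch) {
        std::vector<int32_t> out;
        for (auto & s : seqs) {
            if (s.empty()) { continue; }
            const size_t take = std::min(n_ubatch, s.size());
            out.insert(out.end(), s.begin(), s.begin() + take);
            s.erase(s.begin(), s.begin() + take);
            break;
        }
        return out;
    }
};

int main() {
    const size_t n_ubatch = 4;
    const size_t rs_size  = 1; // > 0 stands in for rs_self.size > 0 (recurrent/hybrid model)

    toy_sbatch sbatch;
    sbatch.seqs = {{0, 1, 2, 3, 4}, {10, 11, 12}};

    while (sbatch.n_tokens() > 0) {
        // same shape as the loop in the diff: equal splits only when a recurrent state exists
        std::vector<int32_t> u_batch = (rs_size > 0)
            ? sbatch.split_equal(n_ubatch)
            : sbatch.split_slice(n_ubatch);
        std::printf("ubatch of %zu tokens\n", u_batch.size());
    }
    return 0;
}
```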