From 72eea49224e5b90263de08f8cddc6010353841eb Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sat, 1 Jun 2024 12:24:19 -0400
Subject: [PATCH] llama : fix batch split output count for embeddings

---
 llama.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llama.cpp b/llama.cpp
index 6878bc893..7c6afa7d1 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13730,7 +13730,9 @@ static int llama_decode_internal(
         n_outputs = 1;
     }
 
-    lctx.sbatch.from_batch(batch_all, n_embd, /* legacy_split */ rs_self.size == 0, lctx.logits_all);
+    lctx.sbatch.from_batch(batch_all, n_embd,
+        /* legacy_split */ rs_self.size == 0,
+        /* logits_all */ n_outputs == n_tokens_all);
 
     // reserve output buffer
     if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
@@ -13740,6 +13742,7 @@ static int llama_decode_internal(
 
     while (lctx.sbatch.n_tokens > 0) {
         // TODO: deprecate slice splits in favor of equal splits
+        // For now, only use equal splits for recurrent or hybrid model architectures
         llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
         const uint32_t n_tokens = u_batch.n_tokens;