llama : fix batch split output count for embeddings

This commit is contained in:
Francis Couture-Harpin 2024-06-01 12:24:19 -04:00
parent 5d3c7b9585
commit 72eea49224

View File

@ -13730,7 +13730,9 @@ static int llama_decode_internal(
n_outputs = 1; n_outputs = 1;
} }
lctx.sbatch.from_batch(batch_all, n_embd, /* legacy_split */ rs_self.size == 0, lctx.logits_all); lctx.sbatch.from_batch(batch_all, n_embd,
/* legacy_split */ rs_self.size == 0,
/* logits_all */ n_outputs == n_tokens_all);
// reserve output buffer // reserve output buffer
if (llama_output_reserve(lctx, n_outputs) < n_outputs) { if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
@ -13740,6 +13742,7 @@ static int llama_decode_internal(
while (lctx.sbatch.n_tokens > 0) { while (lctx.sbatch.n_tokens > 0) {
// TODO: deprecate slice splits in favor of equal splits // TODO: deprecate slice splits in favor of equal splits
// For now, only use equal splits for recurrent or hybrid model architectures
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch); llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
const uint32_t n_tokens = u_batch.n_tokens; const uint32_t n_tokens = u_batch.n_tokens;