mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-19 08:20:10 +01:00
llama : fix batch split output count for embeddings
This commit is contained in:
parent
5d3c7b9585
commit
72eea49224
@ -13730,7 +13730,9 @@ static int llama_decode_internal(
|
||||
n_outputs = 1;
|
||||
}
|
||||
|
||||
lctx.sbatch.from_batch(batch_all, n_embd, /* legacy_split */ rs_self.size == 0, lctx.logits_all);
|
||||
lctx.sbatch.from_batch(batch_all, n_embd,
|
||||
/* legacy_split */ rs_self.size == 0,
|
||||
/* logits_all */ n_outputs == n_tokens_all);
|
||||
|
||||
// reserve output buffer
|
||||
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
||||
@ -13740,6 +13742,7 @@ static int llama_decode_internal(
|
||||
|
||||
while (lctx.sbatch.n_tokens > 0) {
|
||||
// TODO: deprecate slice splits in favor of equal splits
|
||||
// For now, only use equal splits for recurrent or hybrid model architectures
|
||||
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
|
||||
const uint32_t n_tokens = u_batch.n_tokens;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user