From 61200ef29fc0e76f264ada583b77e9228120779f Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 16:41:22 -0400 Subject: [PATCH] llama : fix edge case finding batch seq_id of split recurrent cell This otherwise was a problem when running the HellaSwag benchmark with small batch sizes, making it crash. --- llama.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index d44dfe7b2..62d66c2bc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3879,11 +3879,17 @@ static bool llama_cache_find_slot( if (cell.tail_rc == 0) { cache.rs.clear_cell(cell); } else { - // TODO: does this always work correctly - // even if there are more than one seq_node in this cell? + // Find the seq_id of the first tail of this cell + llama_seq_id seq_id = -1; + for (llama_rs_seq_node & seq_node : cell.seq_nodes) { + if (seq_node.is_tail()) { + seq_id = seq_node.seq_id; + break; + } + } + GGML_ASSERT(seq_id != -1); // Which seq_id of the batch is it? - llama_seq_id seq_id = cell.seq_nodes[0].seq_id; int32_t nth_seq_id = -1; for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) { if (seq_id == batch.seq_id[s][0]) {