mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
llama : fix edge case finding batch seq_id of split recurrent cell
Without this fix, running the HellaSwag benchmark with small batch sizes would crash.
This commit is contained in:
parent
18d1c14047
commit
61200ef29f
12
llama.cpp
12
llama.cpp
@ -3879,11 +3879,17 @@ static bool llama_cache_find_slot(
|
|||||||
if (cell.tail_rc == 0) {
|
if (cell.tail_rc == 0) {
|
||||||
cache.rs.clear_cell(cell);
|
cache.rs.clear_cell(cell);
|
||||||
} else {
|
} else {
|
||||||
// TODO: does this always work correctly
|
// Find the seq_id of the first tail of this cell
|
||||||
// even if there are more than one seq_node in this cell?
|
llama_seq_id seq_id = -1;
|
||||||
|
for (llama_rs_seq_node & seq_node : cell.seq_nodes) {
|
||||||
|
if (seq_node.is_tail()) {
|
||||||
|
seq_id = seq_node.seq_id;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GGML_ASSERT(seq_id != -1);
|
||||||
|
|
||||||
// Which seq_id of the batch is it?
|
// Which seq_id of the batch is it?
|
||||||
llama_seq_id seq_id = cell.seq_nodes[0].seq_id;
|
|
||||||
int32_t nth_seq_id = -1;
|
int32_t nth_seq_id = -1;
|
||||||
for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) {
|
for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) {
|
||||||
if (seq_id == batch.seq_id[s][0]) {
|
if (seq_id == batch.seq_id[s][0]) {
|
||||||
|
Loading…
Reference in New Issue
Block a user