From 0c8b3b20956521acc8f1f297cb58ab3172b3c3e7 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 9 Apr 2024 17:35:22 -0400 Subject: [PATCH] llama : correctly handle more edge cases for the rs cache --- llama.cpp | 407 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 211 insertions(+), 196 deletions(-) diff --git a/llama.cpp b/llama.cpp index d561f80f6..5433bde86 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2034,7 +2034,7 @@ struct llama_rs_cache { uint32_t n = 0; // range of states used for the last slot // useful to know the minimum reserved cell count per seq_id - // only counts sequences with n_cells > 0 AND which have a non-shared tail + // only counts sequences which have a non-shared tail uint32_t n_seqs = 0; // cells part of multiple sequences AND which have at least one tail uint32_t n_shared_tail_cells = 0; @@ -2082,21 +2082,37 @@ struct llama_rs_cache { llama_rs_cell & cell = cells[cell_id]; if (cell.seq_nodes.empty()) { if (cell.pos >= 0) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.pos); + } cell.pos = -1; was_valid = false; } } if (cell.pos < 0) { if (cell.pos != -1) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.pos); + } cell.pos = -1; was_valid = false; } if (!cell.seq_nodes.empty()) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d] has %zu seq_ids while it's empty (should have none)\n", + __func__, cell_id, cell.seq_nodes.size()); + } cell.seq_nodes.clear(); was_valid = false; } cell.src = -1; if (cell.prev != -1) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].prev is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.prev); + } cell.prev = -1; was_valid = false; } @@ -2213,17 +2229,15 @@ struct llama_rs_cache { // n_seqs uint32_t n_seqs_verif = 0; uint32_t n_shared_tail_cells_verif = 0; - for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { - auto & seq = seq_tails[seq_id]; - if (seq.tail >= 0) { - llama_rs_cell & tail_cell = cells[seq.tail]; - // NOTE: could also have checked if n_cells > 0 - if (!tail_cell.seq_nodes.empty() && tail_cell.seq_nodes[0].seq_id == seq_id) { - if (tail_cell.seq_nodes.size() > 1) { - n_shared_tail_cells_verif += 1; - } else { + for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { + llama_rs_cell & rs_cell = cells[cell_id]; + if (!rs_cell.seq_nodes.empty()) { + if (rs_cell.seq_nodes.size() == 1) { + if (rs_cell.tail_rc == 1) { n_seqs_verif += 1; } + } else if (rs_cell.tail_rc > 0) { + n_shared_tail_cells_verif += 1; } } } @@ -2246,72 +2260,15 @@ struct llama_rs_cache { return was_valid; } - // returns whether or not a cell was freed - void clear_cell(llama_rs_cell & rs_cell) { - GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); - if (!rs_cell.is_empty()) { - // update sequence tree links - bool first = true; - for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - // NOTE: if all next cells are the same cell, this should still work - cells[node.next_cell].prev = rs_cell.prev; - } - // next_cell of the nodes of the previous cell - if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { - llama_rs_cell & prev_cell = cells[rs_cell.prev]; - auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); - // assuming the previous node is always found - GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); - prev_node->next_cell = node.next_cell; - if (node.is_tail()) { - prev_cell.tail_rc += 1; - } - } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - // update tail - if (node.is_tail()) { - seq.tail = rs_cell.prev; - } - // cell counts - if (first) { - seq.n_cells -= 1; - if (rs_cell.tail_rc > 0 && seq.tail < 0) { - // last tail cell - if (rs_cell.seq_nodes.size() > 1) { - n_shared_tail_cells -= 1; - } else { - n_seqs -= 1; - } - } - first = false; - } - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - } - rs_cell.pos = -1; - rs_cell.src = -1; - rs_cell.prev = -1; - rs_cell.tail_rc = 0; - rs_cell.seq_nodes.clear(); - used -= 1; - } - } - // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); // TODO: assert the iterator points inside the correct vector if (node_iter != rs_cell.seq_nodes.end()) { - if (rs_cell.seq_nodes.size() == 1) { - clear_cell(rs_cell); - return rs_cell.seq_nodes.end(); - } - // else update tree + // update the tree llama_rs_seq_node node = *node_iter; if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: because of this, partially removing seq_ids from cells should only be done from the tail cells[node.next_cell].prev = rs_cell.prev; } if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { @@ -2321,6 +2278,14 @@ struct llama_rs_cache { GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); prev_node->next_cell = node.next_cell; if (node.is_tail()) { + if (prev_cell.seq_nodes.size() > 1) { + if (prev_cell.tail_rc == 0) { + n_shared_tail_cells += 1; + } + if (rs_cell.seq_nodes.size() == 1) { + n_seqs -= 1; + } + } prev_cell.tail_rc += 1; } } @@ -2328,11 +2293,15 @@ struct llama_rs_cache { auto & seq = seq_tails[node.seq_id]; if (node.is_tail()) { seq.tail = rs_cell.prev; - if (seq.tail < 0 && rs_cell.tail_rc == 1) { - // assuming the previous cell of a shared cell is also shared, - // (no need to update the shared tail cells count elsewhere, then) - // this was a shared tail cell, but will no longer be a tail cell - n_shared_tail_cells -= 1; + if (rs_cell.tail_rc == 1) { + if (rs_cell.seq_nodes.size() > 1) { + // assuming the previous cell of a shared cell is also shared, + // this was a shared tail cell, but will no longer be a tail cell + n_shared_tail_cells -= 1; + } else if (seq.tail < 0) { + // no more tail, no more sequence + n_seqs -= 1; + } } GGML_ASSERT(rs_cell.tail_rc > 0); rs_cell.tail_rc -= 1; @@ -2341,21 +2310,30 @@ struct llama_rs_cache { // this seq_id was the first in the list seq.n_cells -= 1; - // the next node is the new first one, so update its n_cells - // (will never be out-of-bounds because the size is > 1) - llama_rs_seq_node next_node = *(std::next(node_iter)); - if ((uint32_t) next_node.seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node.seq_id]; - next_seq.n_cells += 1; - // only the tail ref count from the other seq_ids are left in tail_rc - if (rs_cell.tail_rc > 0) { - // will become a non-shared cell - if (rs_cell.seq_nodes.size() == 2) { - n_seqs += 1; + auto next_node = std::next(node_iter); + if (next_node != rs_cell.seq_nodes.end()) { + // the next node is the new first one, so update its n_cells + if ((uint32_t) next_node->seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node->seq_id]; + next_seq.n_cells += 1; + // only the tail ref count from the other seq_ids are left in tail_rc + if (rs_cell.tail_rc > 0) { + // will become a non-shared cell + if (rs_cell.seq_nodes.size() == 2) { + n_shared_tail_cells -= 1; + n_seqs += 1; + } } + } else { + GGML_ASSERT(false && "invalid seq_id"); } } else { - GGML_ASSERT(false && "invalid seq_id"); + // this was the last seq_id of the cell + used -= 1; + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + // the other fields *should* have already been updated elsewhere } } } else { @@ -2366,6 +2344,13 @@ struct llama_rs_cache { return node_iter; } + void clear_cell(llama_rs_cell & rs_cell) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { + node_iter = remove_seq_node_from_cell(rs_cell, node_iter); + } + } + // returns whether or not the seq_id was removed bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { if (i_cell < size && (size_t) id < size) { @@ -2404,47 +2389,63 @@ struct llama_rs_cache { prev_cell.tail_rc -= 1; prev_node->next_cell = i_cell; rs_cell.prev = prev; + if (seq.tail == prev) { + // What to do when the tail moves... + // from unique to shared (n_seqs--) + // if the new cell has one seq_id or has no tails (n_shared_tail_cells++) + // if the new cell has one seq_id and a tail (n_seqs-- (yes, another time)) + // from unique to unique (seq.n_cells++) + // from empty to unique (seq.n_cells++, n_seqs++) + // from empty to shared + // if the new cell only has one seq_id or has no tail (n_shared_tail_cells++) + // if the new cell only has one seq_id and has one tail (n_seqs--) + // from shared to shared + // if the last cell has no tails (n_shared_tail_cells--) + // if the new cell has no tails or has one seq_id (n_shared_tail_cells++) + // if the new cell only has one seq_id and has one tail (n_seqs--) + // from shared to unique (seq.n_cells++) + // if this seq_id was not the first of the last cell (n_seqs++) + // if the last cell has no tails (n_shared_tail_cells--) + if (prev_cell.seq_nodes.size() > 1) { + // from shared + if (rs_cell.is_empty()) { + // to unique + if (prev_cell.seq_nodes[0].seq_id != id) { + n_seqs += 1; + } + } + // the previous cell is no longer a shared tail + if (prev_cell.tail_rc == 0) { + n_shared_tail_cells -= 1; + } + } else if (!rs_cell.is_empty()) { + // from unique to shared + n_seqs -= 1; + } + } } if (rs_cell.is_empty()) { - // either the sequence didn't own any cells or had a shared tail cell - if (seq.n_cells == 0 || (seq.tail >= 0 && cells[seq.tail].seq_nodes.size() > 1)) { - n_seqs += 1; - } + // to unique seq.n_cells += 1; - // set pos if still unset - if (rs_cell.pos < 0) { + if (seq.tail < 0) { + // from empty to unique + n_seqs += 1; + // pos was not yet set rs_cell.pos = 0; rs_cell.src = -1; } used += 1; - } else if (rs_cell.seq_nodes.size() == 1 && rs_cell.tail_rc == 1) { - // don't count shared-cell tails - // FIXME: make this saner - n_seqs -= 1; - n_shared_tail_cells += 1; - } else if (rs_cell.tail_rc == 0) { - // shared cell without a tail gets a tail; - // FIXME: don't prune, in case this is used in llama_cache_seq_cp - GGML_ASSERT(false); // make sure we don't get here by accident - // prune the other sequences out of this cell - // NOTE: have to inline the removal because the state tree is partially invalid - bool first = true; - for (auto & node : rs_cell.seq_nodes) { - GGML_ASSERT(node.seq_id != id); - GGML_ASSERT(node.next_cell >= 0); - // easy removal, none of the nodes are tails - llama_rs_cell & next_cell = cells[node.next_cell]; - next_cell.prev = rs_cell.prev; - if (first) { - auto & first_seq = seq_tails[node.seq_id]; - first_seq.n_cells -= 1; - first = false; + } else { + // to shared + if (rs_cell.seq_nodes.size() == 1) { + // a lone tail becomes a shared cell + if (rs_cell.tail_rc > 0) { + n_seqs -= 1; } + n_shared_tail_cells += 1; + } else if (rs_cell.tail_rc == 0) { + n_shared_tail_cells += 1; } - rs_cell.seq_nodes.clear(); - } else if (rs_cell.seq_nodes.size() != rs_cell.tail_rc) { - // this is correct as long as this isn't called when trying to find a slot - // TODO: find a way to assert this } // the target cell was not already a tail of this seq_id rs_cell.insert_node(id); // next_cell == -1 by default @@ -2977,6 +2978,7 @@ static bool llama_kv_cache_find_slot( llama_rs_cell & candidate = cache.rs.cells[cell_id]; if (candidate.is_empty()) { break; } if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { + // the candidate is the old tail if (candidate.seq_nodes.size() > 1) { // prune out the other seq_ids, because they diverge // TODO(maybe): hande this in insert_seq_tail_to_cell_id @@ -3198,40 +3200,42 @@ static llama_pos llama_cache_seq_rm( llama_pos new_p0 = 0; llama_pos new_p1 = std::numeric_limits::max(); - for (uint32_t i = 0; i < cache.rs.size; ++i) { - llama_rs_cell & rs_cell = cache.rs.cells[i]; - auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + // partial seq_id removal has to happen from the tail + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + int32_t cell_id = seq.tail; + + while (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + // copy before the cell is potentially changed + int32_t prev_id = rs_cell.prev; + if (rs_cell.pos >= p1 && rs_cell.seq_nodes.size() > 1) { + // non-tail removal for shared cells can only be done when clearing a cell + // (i.e. when the next cell's link to the previous cell can be safely changed) + p1 = rs_cell.pos + 1; + } + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + // if the node isn't found, the sequence tree is malformed + GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); + cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + // get the smallest removed cell id + if (new_head > (uint32_t) cell_id) { new_head = cell_id; } + } else { + // one more than the biggest non-removed cell of this sequence + if (rs_cell.pos >= n_past) { n_past = rs_cell.pos + 1; } - if (seq_id < 0 || seq_node != rs_cell.seq_nodes.end()) { if (rs_cell.pos < p0) { - // move forward the new p0 further - if (rs_cell.pos >= new_p0) { - new_p0 = rs_cell.pos + 1; - } - } else if (rs_cell.pos >= p1) { - // move back the new p1 further - if (rs_cell.pos < new_p1) { - new_p1 = rs_cell.pos; - } - if (rs_cell.pos >= n_past) { - n_past = rs_cell.pos + 1; - } - } else { // (rs_cell.pos >= p0 && rs_cell.pos < p1) - if (seq_id < 0) { - cache.rs.clear_cell(rs_cell); - } else { // (rs_cell.has_seq_id(seq_id)) - cache.rs.remove_seq_node_from_cell(rs_cell, seq_node); - } - if (rs_cell.is_empty() && new_head == cache.rs.size) { - new_head = i; - } + // new_p0 should be right after the max pos in the states before p0 + if (rs_cell.pos >= new_p0) { new_p0 = rs_cell.pos + 1; } + } else { // (rs_cell.pos >= p1) + // new_p1 should be the min pos in the states after p1 + if (rs_cell.pos < new_p1) { new_p1 = rs_cell.pos; } } } + cell_id = prev_id; } p0 = new_p0; p1 = new_p1; - // correctly set n_past when there's nothing after p1 - if (n_past < p0) { n_past = p0; } // If we freed up a slot, set head to it so searching can start there. if (new_head != cache.rs.size && new_head < cache.rs.head) { @@ -3259,10 +3263,8 @@ static llama_pos llama_cache_seq_rm( kv_cell.pos = -1; if (new_head == cache.kv.size) { new_head = i; } } - } else { - if (kv_cell.pos >= n_past) { - n_past = kv_cell.pos + 1; - } + } else if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; } } } @@ -3292,42 +3294,37 @@ static llama_pos llama_cache_seq_cp( llama_pos n_past = 0; if (cache.rs.size > 0) { - // have to start from beginning for recurrent models + // have to start from the beginning for recurrent models p0 = 0; if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { - auto seq_src = cache.rs.seq_tails[seq_id_src]; - int32_t src_tail = seq_src.tail; - // find the last tail of src in the pos range - while (src_tail >= 0 && (uint32_t) src_tail < cache.rs.size) { - llama_rs_cell & tail_cell = cache.rs.cells[src_tail]; - if (tail_cell.pos < p1) { - break; - } - src_tail = tail_cell.prev; - } - - uint32_t new_head = cache.rs.size; - + int32_t src_head = -1; + int32_t head_pos = p1; + int32_t src_next = -1; + // find the start of the sequence for (uint32_t i = 0; i < cache.rs.size; ++i) { llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) { - if (i == (uint32_t) src_tail) { - // need to be inserted in order, but there's only one - cache.rs.insert_seq_tail_to_cell_id(i, seq_id_dst); - } else { - // keep only the tail cell of the source - // assuming a copy means no rollback will be attempted afterwards - cache.rs.remove_seq_from_cell_id(i, seq_id_src); - if (new_head == cache.rs.size) { - new_head = i; - } + if (!rs_cell.is_empty() && rs_cell.prev < 0) { + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); + if (seq_node != rs_cell.seq_nodes.end()) { + src_head = i; + head_pos = rs_cell.pos; + src_next = seq_node->next_cell; + break; } } } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.rs.size && new_head < cache.rs.head) { - cache.rs.head = new_head; + while (src_head >= 0 && head_pos < p1) { + cache.rs.insert_seq_tail_to_cell_id(src_head, seq_id_dst); + src_head = src_next; + if (head_pos >= n_past) { n_past = head_pos + 1; } + if (src_next >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[src_next]; + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); + head_pos = rs_cell.pos; + // it should always be found if the seq tree is valid + GGML_ASSERT(seq_node != rs_cell.seq_nodes.end()); + src_next = seq_node->next_cell; + } } } p1 = n_past; @@ -3338,9 +3335,7 @@ static llama_pos llama_cache_seq_cp( llama_kv_cell & kv_cell = cache.kv.cells[i]; if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { kv_cell.seq_id.insert(seq_id_dst); - if (kv_cell.pos >= n_past) { - n_past = kv_cell.pos + 1; - } + if (kv_cell.pos >= n_past) { n_past = kv_cell.pos + 1; } } } } @@ -3352,18 +3347,19 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id if (cache.rs.size > 0) { uint32_t new_head = cache.rs.size; - for (uint32_t i = 0; i < cache.rs.size; ++i) { - llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (!rs_cell.seq_nodes.empty()) { - for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { - if (node_iter->seq_id != seq_id) { - node_iter = cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); - } else { - node_iter = std::next(node_iter); - } - } - if (new_head == cache.rs.size && rs_cell.is_empty()) { - new_head = i; + // partial seq_id removal has to happen from the tail(s) + for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) { + if (i == (uint32_t) seq_id) { continue; } + llama_rs_seq_meta & seq = cache.rs.seq_tails[i]; + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), i); + GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); + cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + cell_id = rs_cell.prev; + if (new_head > (uint32_t) cell_id && rs_cell.is_empty()) { + new_head = cell_id; } } } @@ -3414,6 +3410,7 @@ static llama_pos llama_cache_seq_add( auto & seq = cache.rs.seq_tails[seq_id]; // follow the sequence from its tail int32_t cell_id = seq.tail; + uint32_t new_head = cache.rs.size; while (cell_id >= 0) { GGML_ASSERT((uint32_t) cell_id < cache.rs.size); llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; @@ -3423,13 +3420,19 @@ static llama_pos llama_cache_seq_add( if (rs_cell.pos < 0) { // NOTE: this affects the other sequences which share the cell cache.rs.clear_cell(rs_cell); - // TODO: update cache.rs.head + if (new_head > (uint32_t) cell_id) { + new_head = cell_id; + } } } if (n_past <= rs_cell.pos) { n_past = rs_cell.pos + 1; } } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.rs.head = new_head != cache.rs.size ? new_head : 0; } if (cache.kv.size > 0) { @@ -3474,8 +3477,8 @@ static llama_pos llama_cache_seq_div( llama_pos p0, llama_pos p1, int d) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } llama_pos n_past = p0; @@ -11275,6 +11278,10 @@ static int llama_decode_internal( } } n_outputs_prev += lctx.n_outputs; + +#ifndef NDEBUG + GGML_ASSERT(lctx.cache.rs.rebuild(true)); +#endif } // wait for the computation to finish (automatically done when obtaining the model output) @@ -16332,11 +16339,19 @@ void llama_batch_free(struct llama_batch batch) { int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { +#ifndef NDEBUG + GGML_ASSERT(ctx->cache.rs.rebuild(true)); +#endif + const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } +#ifndef NDEBUG + GGML_ASSERT(ctx->cache.rs.rebuild(true)); +#endif + return ret; }