mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-22 09:39:08 +01:00
llama : correctly handle more edge cases for the rs cache
This commit is contained in:
parent
0028010d01
commit
0c8b3b2095
375
llama.cpp
375
llama.cpp
@ -2034,7 +2034,7 @@ struct llama_rs_cache {
|
|||||||
uint32_t n = 0; // range of states used for the last slot
|
uint32_t n = 0; // range of states used for the last slot
|
||||||
|
|
||||||
// useful to know the minimum reserved cell count per seq_id
|
// useful to know the minimum reserved cell count per seq_id
|
||||||
// only counts sequences with n_cells > 0 AND which have a non-shared tail
|
// only counts sequences which have a non-shared tail
|
||||||
uint32_t n_seqs = 0;
|
uint32_t n_seqs = 0;
|
||||||
// cells part of multiple sequences AND which have at least one tail
|
// cells part of multiple sequences AND which have at least one tail
|
||||||
uint32_t n_shared_tail_cells = 0;
|
uint32_t n_shared_tail_cells = 0;
|
||||||
@ -2082,21 +2082,37 @@ struct llama_rs_cache {
|
|||||||
llama_rs_cell & cell = cells[cell_id];
|
llama_rs_cell & cell = cells[cell_id];
|
||||||
if (cell.seq_nodes.empty()) {
|
if (cell.seq_nodes.empty()) {
|
||||||
if (cell.pos >= 0) {
|
if (cell.pos >= 0) {
|
||||||
|
if (debug) {
|
||||||
|
LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n",
|
||||||
|
__func__, cell_id, cell.pos);
|
||||||
|
}
|
||||||
cell.pos = -1;
|
cell.pos = -1;
|
||||||
was_valid = false;
|
was_valid = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (cell.pos < 0) {
|
if (cell.pos < 0) {
|
||||||
if (cell.pos != -1) {
|
if (cell.pos != -1) {
|
||||||
|
if (debug) {
|
||||||
|
LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n",
|
||||||
|
__func__, cell_id, cell.pos);
|
||||||
|
}
|
||||||
cell.pos = -1;
|
cell.pos = -1;
|
||||||
was_valid = false;
|
was_valid = false;
|
||||||
}
|
}
|
||||||
if (!cell.seq_nodes.empty()) {
|
if (!cell.seq_nodes.empty()) {
|
||||||
|
if (debug) {
|
||||||
|
LLAMA_LOG_ERROR("%s: cells[%d] has %zu seq_ids while it's empty (should have none)\n",
|
||||||
|
__func__, cell_id, cell.seq_nodes.size());
|
||||||
|
}
|
||||||
cell.seq_nodes.clear();
|
cell.seq_nodes.clear();
|
||||||
was_valid = false;
|
was_valid = false;
|
||||||
}
|
}
|
||||||
cell.src = -1;
|
cell.src = -1;
|
||||||
if (cell.prev != -1) {
|
if (cell.prev != -1) {
|
||||||
|
if (debug) {
|
||||||
|
LLAMA_LOG_ERROR("%s: cells[%d].prev is %d while it's empty (should be -1)\n",
|
||||||
|
__func__, cell_id, cell.prev);
|
||||||
|
}
|
||||||
cell.prev = -1;
|
cell.prev = -1;
|
||||||
was_valid = false;
|
was_valid = false;
|
||||||
}
|
}
|
||||||
@ -2213,17 +2229,15 @@ struct llama_rs_cache {
|
|||||||
// n_seqs
|
// n_seqs
|
||||||
uint32_t n_seqs_verif = 0;
|
uint32_t n_seqs_verif = 0;
|
||||||
uint32_t n_shared_tail_cells_verif = 0;
|
uint32_t n_shared_tail_cells_verif = 0;
|
||||||
for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) {
|
for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) {
|
||||||
auto & seq = seq_tails[seq_id];
|
llama_rs_cell & rs_cell = cells[cell_id];
|
||||||
if (seq.tail >= 0) {
|
if (!rs_cell.seq_nodes.empty()) {
|
||||||
llama_rs_cell & tail_cell = cells[seq.tail];
|
if (rs_cell.seq_nodes.size() == 1) {
|
||||||
// NOTE: could also have checked if n_cells > 0
|
if (rs_cell.tail_rc == 1) {
|
||||||
if (!tail_cell.seq_nodes.empty() && tail_cell.seq_nodes[0].seq_id == seq_id) {
|
|
||||||
if (tail_cell.seq_nodes.size() > 1) {
|
|
||||||
n_shared_tail_cells_verif += 1;
|
|
||||||
} else {
|
|
||||||
n_seqs_verif += 1;
|
n_seqs_verif += 1;
|
||||||
}
|
}
|
||||||
|
} else if (rs_cell.tail_rc > 0) {
|
||||||
|
n_shared_tail_cells_verif += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2246,72 +2260,15 @@ struct llama_rs_cache {
|
|||||||
return was_valid;
|
return was_valid;
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns whether or not a cell was freed
|
|
||||||
void clear_cell(llama_rs_cell & rs_cell) {
|
|
||||||
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
|
|
||||||
if (!rs_cell.is_empty()) {
|
|
||||||
// update sequence tree links
|
|
||||||
bool first = true;
|
|
||||||
for (const llama_rs_seq_node & node : rs_cell.seq_nodes) {
|
|
||||||
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
|
|
||||||
// NOTE: if all next cells are the same cell, this should still work
|
|
||||||
cells[node.next_cell].prev = rs_cell.prev;
|
|
||||||
}
|
|
||||||
// next_cell of the nodes of the previous cell
|
|
||||||
if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) {
|
|
||||||
llama_rs_cell & prev_cell = cells[rs_cell.prev];
|
|
||||||
auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node);
|
|
||||||
// assuming the previous node is always found
|
|
||||||
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
|
|
||||||
prev_node->next_cell = node.next_cell;
|
|
||||||
if (node.is_tail()) {
|
|
||||||
prev_cell.tail_rc += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ((uint32_t) node.seq_id < seq_tails.size()) {
|
|
||||||
auto & seq = seq_tails[node.seq_id];
|
|
||||||
// update tail
|
|
||||||
if (node.is_tail()) {
|
|
||||||
seq.tail = rs_cell.prev;
|
|
||||||
}
|
|
||||||
// cell counts
|
|
||||||
if (first) {
|
|
||||||
seq.n_cells -= 1;
|
|
||||||
if (rs_cell.tail_rc > 0 && seq.tail < 0) {
|
|
||||||
// last tail cell
|
|
||||||
if (rs_cell.seq_nodes.size() > 1) {
|
|
||||||
n_shared_tail_cells -= 1;
|
|
||||||
} else {
|
|
||||||
n_seqs -= 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
first = false;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
GGML_ASSERT(false && "invalid seq_id");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rs_cell.pos = -1;
|
|
||||||
rs_cell.src = -1;
|
|
||||||
rs_cell.prev = -1;
|
|
||||||
rs_cell.tail_rc = 0;
|
|
||||||
rs_cell.seq_nodes.clear();
|
|
||||||
used -= 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed.
|
// returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed.
|
||||||
std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) {
|
std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) {
|
||||||
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
|
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
|
||||||
// TODO: assert the iterator points inside the correct vector
|
// TODO: assert the iterator points inside the correct vector
|
||||||
if (node_iter != rs_cell.seq_nodes.end()) {
|
if (node_iter != rs_cell.seq_nodes.end()) {
|
||||||
if (rs_cell.seq_nodes.size() == 1) {
|
// update the tree
|
||||||
clear_cell(rs_cell);
|
|
||||||
return rs_cell.seq_nodes.end();
|
|
||||||
}
|
|
||||||
// else update tree
|
|
||||||
llama_rs_seq_node node = *node_iter;
|
llama_rs_seq_node node = *node_iter;
|
||||||
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
|
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
|
||||||
|
// NOTE: because of this, partially removing seq_ids from cells should only be done from the tail
|
||||||
cells[node.next_cell].prev = rs_cell.prev;
|
cells[node.next_cell].prev = rs_cell.prev;
|
||||||
}
|
}
|
||||||
if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) {
|
if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) {
|
||||||
@ -2321,6 +2278,14 @@ struct llama_rs_cache {
|
|||||||
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
|
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
|
||||||
prev_node->next_cell = node.next_cell;
|
prev_node->next_cell = node.next_cell;
|
||||||
if (node.is_tail()) {
|
if (node.is_tail()) {
|
||||||
|
if (prev_cell.seq_nodes.size() > 1) {
|
||||||
|
if (prev_cell.tail_rc == 0) {
|
||||||
|
n_shared_tail_cells += 1;
|
||||||
|
}
|
||||||
|
if (rs_cell.seq_nodes.size() == 1) {
|
||||||
|
n_seqs -= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
prev_cell.tail_rc += 1;
|
prev_cell.tail_rc += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2328,11 +2293,15 @@ struct llama_rs_cache {
|
|||||||
auto & seq = seq_tails[node.seq_id];
|
auto & seq = seq_tails[node.seq_id];
|
||||||
if (node.is_tail()) {
|
if (node.is_tail()) {
|
||||||
seq.tail = rs_cell.prev;
|
seq.tail = rs_cell.prev;
|
||||||
if (seq.tail < 0 && rs_cell.tail_rc == 1) {
|
if (rs_cell.tail_rc == 1) {
|
||||||
|
if (rs_cell.seq_nodes.size() > 1) {
|
||||||
// assuming the previous cell of a shared cell is also shared,
|
// assuming the previous cell of a shared cell is also shared,
|
||||||
// (no need to update the shared tail cells count elsewhere, then)
|
|
||||||
// this was a shared tail cell, but will no longer be a tail cell
|
// this was a shared tail cell, but will no longer be a tail cell
|
||||||
n_shared_tail_cells -= 1;
|
n_shared_tail_cells -= 1;
|
||||||
|
} else if (seq.tail < 0) {
|
||||||
|
// no more tail, no more sequence
|
||||||
|
n_seqs -= 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
GGML_ASSERT(rs_cell.tail_rc > 0);
|
GGML_ASSERT(rs_cell.tail_rc > 0);
|
||||||
rs_cell.tail_rc -= 1;
|
rs_cell.tail_rc -= 1;
|
||||||
@ -2341,22 +2310,31 @@ struct llama_rs_cache {
|
|||||||
// this seq_id was the first in the list
|
// this seq_id was the first in the list
|
||||||
seq.n_cells -= 1;
|
seq.n_cells -= 1;
|
||||||
|
|
||||||
|
auto next_node = std::next(node_iter);
|
||||||
|
if (next_node != rs_cell.seq_nodes.end()) {
|
||||||
// the next node is the new first one, so update its n_cells
|
// the next node is the new first one, so update its n_cells
|
||||||
// (will never be out-of-bounds because the size is > 1)
|
if ((uint32_t) next_node->seq_id < seq_tails.size()) {
|
||||||
llama_rs_seq_node next_node = *(std::next(node_iter));
|
auto & next_seq = seq_tails[next_node->seq_id];
|
||||||
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
|
|
||||||
auto & next_seq = seq_tails[next_node.seq_id];
|
|
||||||
next_seq.n_cells += 1;
|
next_seq.n_cells += 1;
|
||||||
// only the tail ref count from the other seq_ids are left in tail_rc
|
// only the tail ref count from the other seq_ids are left in tail_rc
|
||||||
if (rs_cell.tail_rc > 0) {
|
if (rs_cell.tail_rc > 0) {
|
||||||
// will become a non-shared cell
|
// will become a non-shared cell
|
||||||
if (rs_cell.seq_nodes.size() == 2) {
|
if (rs_cell.seq_nodes.size() == 2) {
|
||||||
|
n_shared_tail_cells -= 1;
|
||||||
n_seqs += 1;
|
n_seqs += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false && "invalid seq_id");
|
GGML_ASSERT(false && "invalid seq_id");
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// this was the last seq_id of the cell
|
||||||
|
used -= 1;
|
||||||
|
rs_cell.pos = -1;
|
||||||
|
rs_cell.src = -1;
|
||||||
|
rs_cell.prev = -1;
|
||||||
|
// the other fields *should* have already been updated elsewhere
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false && "invalid seq_id");
|
GGML_ASSERT(false && "invalid seq_id");
|
||||||
@ -2366,6 +2344,13 @@ struct llama_rs_cache {
|
|||||||
return node_iter;
|
return node_iter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void clear_cell(llama_rs_cell & rs_cell) {
|
||||||
|
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
|
||||||
|
for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) {
|
||||||
|
node_iter = remove_seq_node_from_cell(rs_cell, node_iter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// returns whether or not the seq_id was removed
|
// returns whether or not the seq_id was removed
|
||||||
bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) {
|
bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) {
|
||||||
if (i_cell < size && (size_t) id < size) {
|
if (i_cell < size && (size_t) id < size) {
|
||||||
@ -2404,48 +2389,64 @@ struct llama_rs_cache {
|
|||||||
prev_cell.tail_rc -= 1;
|
prev_cell.tail_rc -= 1;
|
||||||
prev_node->next_cell = i_cell;
|
prev_node->next_cell = i_cell;
|
||||||
rs_cell.prev = prev;
|
rs_cell.prev = prev;
|
||||||
}
|
if (seq.tail == prev) {
|
||||||
|
// What to do when the tail moves...
|
||||||
|
// from unique to shared (n_seqs--)
|
||||||
|
// if the new cell has one seq_id or has no tails (n_shared_tail_cells++)
|
||||||
|
// if the new cell has one seq_id and a tail (n_seqs-- (yes, another time))
|
||||||
|
// from unique to unique (seq.n_cells++)
|
||||||
|
// from empty to unique (seq.n_cells++, n_seqs++)
|
||||||
|
// from empty to shared
|
||||||
|
// if the new cell only has one seq_id or has no tail (n_shared_tail_cells++)
|
||||||
|
// if the new cell only has one seq_id and has one tail (n_seqs--)
|
||||||
|
// from shared to shared
|
||||||
|
// if the last cell has no tails (n_shared_tail_cells--)
|
||||||
|
// if the new cell has no tails or has one seq_id (n_shared_tail_cells++)
|
||||||
|
// if the new cell only has one seq_id and has one tail (n_seqs--)
|
||||||
|
// from shared to unique (seq.n_cells++)
|
||||||
|
// if this seq_id was not the first of the last cell (n_seqs++)
|
||||||
|
// if the last cell has no tails (n_shared_tail_cells--)
|
||||||
|
if (prev_cell.seq_nodes.size() > 1) {
|
||||||
|
// from shared
|
||||||
if (rs_cell.is_empty()) {
|
if (rs_cell.is_empty()) {
|
||||||
// either the sequence didn't own any cells or had a shared tail cell
|
// to unique
|
||||||
if (seq.n_cells == 0 || (seq.tail >= 0 && cells[seq.tail].seq_nodes.size() > 1)) {
|
if (prev_cell.seq_nodes[0].seq_id != id) {
|
||||||
n_seqs += 1;
|
n_seqs += 1;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
// the previous cell is no longer a shared tail
|
||||||
|
if (prev_cell.tail_rc == 0) {
|
||||||
|
n_shared_tail_cells -= 1;
|
||||||
|
}
|
||||||
|
} else if (!rs_cell.is_empty()) {
|
||||||
|
// from unique to shared
|
||||||
|
n_seqs -= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (rs_cell.is_empty()) {
|
||||||
|
// to unique
|
||||||
seq.n_cells += 1;
|
seq.n_cells += 1;
|
||||||
// set pos if still unset
|
if (seq.tail < 0) {
|
||||||
if (rs_cell.pos < 0) {
|
// from empty to unique
|
||||||
|
n_seqs += 1;
|
||||||
|
// pos was not yet set
|
||||||
rs_cell.pos = 0;
|
rs_cell.pos = 0;
|
||||||
rs_cell.src = -1;
|
rs_cell.src = -1;
|
||||||
}
|
}
|
||||||
used += 1;
|
used += 1;
|
||||||
} else if (rs_cell.seq_nodes.size() == 1 && rs_cell.tail_rc == 1) {
|
} else {
|
||||||
// don't count shared-cell tails
|
// to shared
|
||||||
// FIXME: make this saner
|
if (rs_cell.seq_nodes.size() == 1) {
|
||||||
|
// a lone tail becomes a shared cell
|
||||||
|
if (rs_cell.tail_rc > 0) {
|
||||||
n_seqs -= 1;
|
n_seqs -= 1;
|
||||||
|
}
|
||||||
n_shared_tail_cells += 1;
|
n_shared_tail_cells += 1;
|
||||||
} else if (rs_cell.tail_rc == 0) {
|
} else if (rs_cell.tail_rc == 0) {
|
||||||
// shared cell without a tail gets a tail;
|
n_shared_tail_cells += 1;
|
||||||
// FIXME: don't prune, in case this is used in llama_cache_seq_cp
|
|
||||||
GGML_ASSERT(false); // make sure we don't get here by accident
|
|
||||||
// prune the other sequences out of this cell
|
|
||||||
// NOTE: have to inline the removal because the state tree is partially invalid
|
|
||||||
bool first = true;
|
|
||||||
for (auto & node : rs_cell.seq_nodes) {
|
|
||||||
GGML_ASSERT(node.seq_id != id);
|
|
||||||
GGML_ASSERT(node.next_cell >= 0);
|
|
||||||
// easy removal, none of the nodes are tails
|
|
||||||
llama_rs_cell & next_cell = cells[node.next_cell];
|
|
||||||
next_cell.prev = rs_cell.prev;
|
|
||||||
if (first) {
|
|
||||||
auto & first_seq = seq_tails[node.seq_id];
|
|
||||||
first_seq.n_cells -= 1;
|
|
||||||
first = false;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rs_cell.seq_nodes.clear();
|
|
||||||
} else if (rs_cell.seq_nodes.size() != rs_cell.tail_rc) {
|
|
||||||
// this is correct as long as this isn't called when trying to find a slot
|
|
||||||
// TODO: find a way to assert this
|
|
||||||
}
|
|
||||||
// the target cell was not already a tail of this seq_id
|
// the target cell was not already a tail of this seq_id
|
||||||
rs_cell.insert_node(id); // next_cell == -1 by default
|
rs_cell.insert_node(id); // next_cell == -1 by default
|
||||||
rs_cell.tail_rc += 1;
|
rs_cell.tail_rc += 1;
|
||||||
@ -2977,6 +2978,7 @@ static bool llama_kv_cache_find_slot(
|
|||||||
llama_rs_cell & candidate = cache.rs.cells[cell_id];
|
llama_rs_cell & candidate = cache.rs.cells[cell_id];
|
||||||
if (candidate.is_empty()) { break; }
|
if (candidate.is_empty()) { break; }
|
||||||
if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) {
|
if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) {
|
||||||
|
// the candidate is the old tail
|
||||||
if (candidate.seq_nodes.size() > 1) {
|
if (candidate.seq_nodes.size() > 1) {
|
||||||
// prune out the other seq_ids, because they diverge
|
// prune out the other seq_ids, because they diverge
|
||||||
// TODO(maybe): hande this in insert_seq_tail_to_cell_id
|
// TODO(maybe): hande this in insert_seq_tail_to_cell_id
|
||||||
@ -3198,40 +3200,42 @@ static llama_pos llama_cache_seq_rm(
|
|||||||
llama_pos new_p0 = 0;
|
llama_pos new_p0 = 0;
|
||||||
llama_pos new_p1 = std::numeric_limits<llama_pos>::max();
|
llama_pos new_p1 = std::numeric_limits<llama_pos>::max();
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cache.rs.size; ++i) {
|
// partial seq_id removal has to happen from the tail
|
||||||
llama_rs_cell & rs_cell = cache.rs.cells[i];
|
llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id];
|
||||||
auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id);
|
int32_t cell_id = seq.tail;
|
||||||
|
|
||||||
|
while (cell_id >= 0) {
|
||||||
|
llama_rs_cell & rs_cell = cache.rs.cells[cell_id];
|
||||||
|
// copy before the cell is potentially changed
|
||||||
|
int32_t prev_id = rs_cell.prev;
|
||||||
|
if (rs_cell.pos >= p1 && rs_cell.seq_nodes.size() > 1) {
|
||||||
|
// non-tail removal for shared cells can only be done when clearing a cell
|
||||||
|
// (i.e. when the next cell's link to the previous cell can be safely changed)
|
||||||
|
p1 = rs_cell.pos + 1;
|
||||||
|
}
|
||||||
|
if (rs_cell.pos >= p0 && rs_cell.pos < p1) {
|
||||||
|
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id);
|
||||||
|
// if the node isn't found, the sequence tree is malformed
|
||||||
|
GGML_ASSERT(node_iter != rs_cell.seq_nodes.end());
|
||||||
|
cache.rs.remove_seq_node_from_cell(rs_cell, node_iter);
|
||||||
|
// get the smallest removed cell id
|
||||||
|
if (new_head > (uint32_t) cell_id) { new_head = cell_id; }
|
||||||
|
} else {
|
||||||
|
// one more than the biggest non-removed cell of this sequence
|
||||||
|
if (rs_cell.pos >= n_past) { n_past = rs_cell.pos + 1; }
|
||||||
|
|
||||||
if (seq_id < 0 || seq_node != rs_cell.seq_nodes.end()) {
|
|
||||||
if (rs_cell.pos < p0) {
|
if (rs_cell.pos < p0) {
|
||||||
// move forward the new p0 further
|
// new_p0 should be right after the max pos in the states before p0
|
||||||
if (rs_cell.pos >= new_p0) {
|
if (rs_cell.pos >= new_p0) { new_p0 = rs_cell.pos + 1; }
|
||||||
new_p0 = rs_cell.pos + 1;
|
} else { // (rs_cell.pos >= p1)
|
||||||
}
|
// new_p1 should be the min pos in the states after p1
|
||||||
} else if (rs_cell.pos >= p1) {
|
if (rs_cell.pos < new_p1) { new_p1 = rs_cell.pos; }
|
||||||
// move back the new p1 further
|
|
||||||
if (rs_cell.pos < new_p1) {
|
|
||||||
new_p1 = rs_cell.pos;
|
|
||||||
}
|
|
||||||
if (rs_cell.pos >= n_past) {
|
|
||||||
n_past = rs_cell.pos + 1;
|
|
||||||
}
|
|
||||||
} else { // (rs_cell.pos >= p0 && rs_cell.pos < p1)
|
|
||||||
if (seq_id < 0) {
|
|
||||||
cache.rs.clear_cell(rs_cell);
|
|
||||||
} else { // (rs_cell.has_seq_id(seq_id))
|
|
||||||
cache.rs.remove_seq_node_from_cell(rs_cell, seq_node);
|
|
||||||
}
|
|
||||||
if (rs_cell.is_empty() && new_head == cache.rs.size) {
|
|
||||||
new_head = i;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
cell_id = prev_id;
|
||||||
}
|
}
|
||||||
p0 = new_p0;
|
p0 = new_p0;
|
||||||
p1 = new_p1;
|
p1 = new_p1;
|
||||||
// correctly set n_past when there's nothing after p1
|
|
||||||
if (n_past < p0) { n_past = p0; }
|
|
||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
if (new_head != cache.rs.size && new_head < cache.rs.head) {
|
if (new_head != cache.rs.size && new_head < cache.rs.head) {
|
||||||
@ -3259,13 +3263,11 @@ static llama_pos llama_cache_seq_rm(
|
|||||||
kv_cell.pos = -1;
|
kv_cell.pos = -1;
|
||||||
if (new_head == cache.kv.size) { new_head = i; }
|
if (new_head == cache.kv.size) { new_head = i; }
|
||||||
}
|
}
|
||||||
} else {
|
} else if (kv_cell.pos >= n_past) {
|
||||||
if (kv_cell.pos >= n_past) {
|
|
||||||
n_past = kv_cell.pos + 1;
|
n_past = kv_cell.pos + 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
if (new_head != cache.kv.size && new_head < cache.kv.head) {
|
if (new_head != cache.kv.size && new_head < cache.kv.head) {
|
||||||
@ -3292,42 +3294,37 @@ static llama_pos llama_cache_seq_cp(
|
|||||||
llama_pos n_past = 0;
|
llama_pos n_past = 0;
|
||||||
|
|
||||||
if (cache.rs.size > 0) {
|
if (cache.rs.size > 0) {
|
||||||
// have to start from beginning for recurrent models
|
// have to start from the beginning for recurrent models
|
||||||
p0 = 0;
|
p0 = 0;
|
||||||
if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) {
|
if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) {
|
||||||
auto seq_src = cache.rs.seq_tails[seq_id_src];
|
int32_t src_head = -1;
|
||||||
int32_t src_tail = seq_src.tail;
|
int32_t head_pos = p1;
|
||||||
// find the last tail of src in the pos range
|
int32_t src_next = -1;
|
||||||
while (src_tail >= 0 && (uint32_t) src_tail < cache.rs.size) {
|
// find the start of the sequence
|
||||||
llama_rs_cell & tail_cell = cache.rs.cells[src_tail];
|
|
||||||
if (tail_cell.pos < p1) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
src_tail = tail_cell.prev;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t new_head = cache.rs.size;
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cache.rs.size; ++i) {
|
for (uint32_t i = 0; i < cache.rs.size; ++i) {
|
||||||
llama_rs_cell & rs_cell = cache.rs.cells[i];
|
llama_rs_cell & rs_cell = cache.rs.cells[i];
|
||||||
if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) {
|
if (!rs_cell.is_empty() && rs_cell.prev < 0) {
|
||||||
if (i == (uint32_t) src_tail) {
|
auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src);
|
||||||
// need to be inserted in order, but there's only one
|
if (seq_node != rs_cell.seq_nodes.end()) {
|
||||||
cache.rs.insert_seq_tail_to_cell_id(i, seq_id_dst);
|
src_head = i;
|
||||||
} else {
|
head_pos = rs_cell.pos;
|
||||||
// keep only the tail cell of the source
|
src_next = seq_node->next_cell;
|
||||||
// assuming a copy means no rollback will be attempted afterwards
|
break;
|
||||||
cache.rs.remove_seq_from_cell_id(i, seq_id_src);
|
|
||||||
if (new_head == cache.rs.size) {
|
|
||||||
new_head = i;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
while (src_head >= 0 && head_pos < p1) {
|
||||||
|
cache.rs.insert_seq_tail_to_cell_id(src_head, seq_id_dst);
|
||||||
|
src_head = src_next;
|
||||||
|
if (head_pos >= n_past) { n_past = head_pos + 1; }
|
||||||
|
if (src_next >= 0) {
|
||||||
|
llama_rs_cell & rs_cell = cache.rs.cells[src_next];
|
||||||
|
auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src);
|
||||||
|
head_pos = rs_cell.pos;
|
||||||
|
// it should always be found if the seq tree is valid
|
||||||
|
GGML_ASSERT(seq_node != rs_cell.seq_nodes.end());
|
||||||
|
src_next = seq_node->next_cell;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
|
||||||
if (new_head != cache.rs.size && new_head < cache.rs.head) {
|
|
||||||
cache.rs.head = new_head;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p1 = n_past;
|
p1 = n_past;
|
||||||
@ -3338,9 +3335,7 @@ static llama_pos llama_cache_seq_cp(
|
|||||||
llama_kv_cell & kv_cell = cache.kv.cells[i];
|
llama_kv_cell & kv_cell = cache.kv.cells[i];
|
||||||
if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) {
|
if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) {
|
||||||
kv_cell.seq_id.insert(seq_id_dst);
|
kv_cell.seq_id.insert(seq_id_dst);
|
||||||
if (kv_cell.pos >= n_past) {
|
if (kv_cell.pos >= n_past) { n_past = kv_cell.pos + 1; }
|
||||||
n_past = kv_cell.pos + 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3352,18 +3347,19 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id
|
|||||||
if (cache.rs.size > 0) {
|
if (cache.rs.size > 0) {
|
||||||
uint32_t new_head = cache.rs.size;
|
uint32_t new_head = cache.rs.size;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cache.rs.size; ++i) {
|
// partial seq_id removal has to happen from the tail(s)
|
||||||
llama_rs_cell & rs_cell = cache.rs.cells[i];
|
for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) {
|
||||||
if (!rs_cell.seq_nodes.empty()) {
|
if (i == (uint32_t) seq_id) { continue; }
|
||||||
for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) {
|
llama_rs_seq_meta & seq = cache.rs.seq_tails[i];
|
||||||
if (node_iter->seq_id != seq_id) {
|
int32_t cell_id = seq.tail;
|
||||||
node_iter = cache.rs.remove_seq_node_from_cell(rs_cell, node_iter);
|
while (cell_id >= 0) {
|
||||||
} else {
|
llama_rs_cell & rs_cell = cache.rs.cells[cell_id];
|
||||||
node_iter = std::next(node_iter);
|
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), i);
|
||||||
}
|
GGML_ASSERT(node_iter != rs_cell.seq_nodes.end());
|
||||||
}
|
cache.rs.remove_seq_node_from_cell(rs_cell, node_iter);
|
||||||
if (new_head == cache.rs.size && rs_cell.is_empty()) {
|
cell_id = rs_cell.prev;
|
||||||
new_head = i;
|
if (new_head > (uint32_t) cell_id && rs_cell.is_empty()) {
|
||||||
|
new_head = cell_id;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3414,6 +3410,7 @@ static llama_pos llama_cache_seq_add(
|
|||||||
auto & seq = cache.rs.seq_tails[seq_id];
|
auto & seq = cache.rs.seq_tails[seq_id];
|
||||||
// follow the sequence from its tail
|
// follow the sequence from its tail
|
||||||
int32_t cell_id = seq.tail;
|
int32_t cell_id = seq.tail;
|
||||||
|
uint32_t new_head = cache.rs.size;
|
||||||
while (cell_id >= 0) {
|
while (cell_id >= 0) {
|
||||||
GGML_ASSERT((uint32_t) cell_id < cache.rs.size);
|
GGML_ASSERT((uint32_t) cell_id < cache.rs.size);
|
||||||
llama_rs_cell & rs_cell = cache.rs.cells[cell_id];
|
llama_rs_cell & rs_cell = cache.rs.cells[cell_id];
|
||||||
@ -3423,13 +3420,19 @@ static llama_pos llama_cache_seq_add(
|
|||||||
if (rs_cell.pos < 0) {
|
if (rs_cell.pos < 0) {
|
||||||
// NOTE: this affects the other sequences which share the cell
|
// NOTE: this affects the other sequences which share the cell
|
||||||
cache.rs.clear_cell(rs_cell);
|
cache.rs.clear_cell(rs_cell);
|
||||||
// TODO: update cache.rs.head
|
if (new_head > (uint32_t) cell_id) {
|
||||||
|
new_head = cell_id;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (n_past <= rs_cell.pos) {
|
if (n_past <= rs_cell.pos) {
|
||||||
n_past = rs_cell.pos + 1;
|
n_past = rs_cell.pos + 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
|
// Otherwise we just start the next search from the beginning.
|
||||||
|
cache.rs.head = new_head != cache.rs.size ? new_head : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cache.kv.size > 0) {
|
if (cache.kv.size > 0) {
|
||||||
@ -3474,8 +3477,8 @@ static llama_pos llama_cache_seq_div(
|
|||||||
llama_pos p0,
|
llama_pos p0,
|
||||||
llama_pos p1,
|
llama_pos p1,
|
||||||
int d) {
|
int d) {
|
||||||
if (p0 < 0) p0 = 0;
|
if (p0 < 0) { p0 = 0; }
|
||||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }
|
||||||
|
|
||||||
llama_pos n_past = p0;
|
llama_pos n_past = p0;
|
||||||
|
|
||||||
@ -11275,6 +11278,10 @@ static int llama_decode_internal(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
n_outputs_prev += lctx.n_outputs;
|
n_outputs_prev += lctx.n_outputs;
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
GGML_ASSERT(lctx.cache.rs.rebuild(true));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for the computation to finish (automatically done when obtaining the model output)
|
// wait for the computation to finish (automatically done when obtaining the model output)
|
||||||
@ -16332,11 +16339,19 @@ void llama_batch_free(struct llama_batch batch) {
|
|||||||
int32_t llama_decode(
|
int32_t llama_decode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch) {
|
struct llama_batch batch) {
|
||||||
|
#ifndef NDEBUG
|
||||||
|
GGML_ASSERT(ctx->cache.rs.rebuild(true));
|
||||||
|
#endif
|
||||||
|
|
||||||
const int ret = llama_decode_internal(*ctx, batch);
|
const int ret = llama_decode_internal(*ctx, batch);
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
GGML_ASSERT(ctx->cache.rs.rebuild(true));
|
||||||
|
#endif
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user