llama : use std::find for seq_nodes in llama_rs_cache

This commit is contained in:
Francis Couture-Harpin 2024-04-04 10:46:43 -04:00
parent 271104c65c
commit 8db1e4d45f

View File

@ -1962,11 +1962,12 @@ struct llama_rs_seq_node {
llama_seq_id seq_id = -1; llama_seq_id seq_id = -1;
int32_t next_cell = -1; int32_t next_cell = -1;
// needed for automatic typecasting with .find() // needed for automatic typecasting from a llama_seq_id
llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {}
bool operator<(const llama_rs_seq_node & other) const { // needed for more convenient std::find
return seq_id < other.seq_id; bool operator==(const llama_rs_seq_node & other) const {
return seq_id == other.seq_id;
} }
bool is_tail() const { bool is_tail() const {
@ -1989,48 +1990,18 @@ struct llama_rs_cell {
// seq_ids by insertion order, to simplify updating n_cells compared to a set // seq_ids by insertion order, to simplify updating n_cells compared to a set
std::vector<llama_rs_seq_node> seq_nodes; std::vector<llama_rs_seq_node> seq_nodes;
llama_rs_seq_node * get_node(const llama_seq_id & id) {
for (size_t i = 0; i < seq_nodes.size(); ++i) {
if (seq_nodes[i].seq_id == id) {
return &seq_nodes[i];
}
}
return nullptr;
}
void insert_node(const llama_rs_seq_node & node) { void insert_node(const llama_rs_seq_node & node) {
llama_rs_seq_node * node_dest = get_node(node.seq_id); auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node);
if (node_dest == nullptr) { if (node_dest == seq_nodes.end()) {
seq_nodes.push_back(node); seq_nodes.push_back(node);
} else { } else {
// overwrite the pre-existing node with the same seq_id if it exists
*node_dest = node; *node_dest = node;
} }
} }
bool remove_node(llama_rs_seq_node * node_ptr) {
if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) {
size_t offset = node_ptr - seq_nodes.data();
if (offset % sizeof(llama_rs_seq_node) == 0) {
offset /= sizeof(llama_rs_seq_node);
if (offset < seq_nodes.size()) {
for (size_t i = offset + 1; i < seq_nodes.size(); ++i) {
seq_nodes[i - 1] = seq_nodes[i];
}
seq_nodes.resize(seq_nodes.size() - 1);
return true;
}
}
}
return false;
}
bool has_seq_id(const llama_seq_id & id) const { bool has_seq_id(const llama_seq_id & id) const {
for (size_t i = 0; i < seq_nodes.size(); ++i) { return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end();
if (seq_nodes[i].seq_id == id) {
return true;
}
}
return false;
} }
bool is_empty() const { bool is_empty() const {
@ -2132,13 +2103,13 @@ struct llama_rs_cache {
bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) { bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) {
if (i_cell < size && (size_t) id < size) { if (i_cell < size && (size_t) id < size) {
llama_rs_cell & rs_cell = cells[i_cell]; llama_rs_cell & rs_cell = cells[i_cell];
auto * node_ptr = rs_cell.get_node(id); // search once auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once
if (node_ptr != nullptr) { if (node_iter != rs_cell.seq_nodes.end()) {
if (rs_cell.seq_nodes.size() == 1) { if (rs_cell.seq_nodes.size() == 1) {
return clear_cell(i_cell); return clear_cell(i_cell);
} else { }
// update tree // else update tree
llama_rs_seq_node node = *node_ptr; llama_rs_seq_node node = *node_iter;
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
cells[node.next_cell].prev = rs_cell.prev; cells[node.next_cell].prev = rs_cell.prev;
} }
@ -2158,7 +2129,7 @@ struct llama_rs_cache {
GGML_ASSERT(rs_cell.tail_rc > 0); GGML_ASSERT(rs_cell.tail_rc > 0);
rs_cell.tail_rc -= 1; rs_cell.tail_rc -= 1;
} }
if (node_ptr == rs_cell.seq_nodes.data()) { if (node_iter == rs_cell.seq_nodes.begin()) {
// this seq_id was the first in the list // this seq_id was the first in the list
seq.n_cells -= 1; seq.n_cells -= 1;
if (seq.n_cells == 0) { if (seq.n_cells == 0) {
@ -2166,7 +2137,7 @@ struct llama_rs_cache {
} }
// the next node is the new first one, so update its n_cells // the next node is the new first one, so update its n_cells
// (will never be out-of-bounds because the size is > 1) // (will never be out-of-bounds because the size is > 1)
llama_rs_seq_node next_node = node_ptr[1]; llama_rs_seq_node next_node = *(std::next(node_iter));
if ((uint32_t) next_node.seq_id < seq_tails.size()) { if ((uint32_t) next_node.seq_id < seq_tails.size()) {
auto & next_seq = seq_tails[next_node.seq_id]; auto & next_seq = seq_tails[next_node.seq_id];
next_seq.n_cells += 1; next_seq.n_cells += 1;
@ -2190,9 +2161,7 @@ struct llama_rs_cache {
} else { } else {
GGML_ASSERT(false && "invalid seq_id"); GGML_ASSERT(false && "invalid seq_id");
} }
const bool removed = rs_cell.remove_node(node_ptr); rs_cell.seq_nodes.erase(node_iter);
GGML_ASSERT(removed);
}
} }
} }
return false; return false;
@ -2215,8 +2184,8 @@ struct llama_rs_cache {
if (prev >= 0 && (uint32_t) prev < size) { if (prev >= 0 && (uint32_t) prev < size) {
// the targeted cell has a previous cell // the targeted cell has a previous cell
llama_rs_cell & prev_cell = cells[prev]; llama_rs_cell & prev_cell = cells[prev];
llama_rs_seq_node * prev_node = prev_cell.get_node(id); auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id);
GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing
GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken
if (rs_cell.pos < 0) { if (rs_cell.pos < 0) {
GGML_ASSERT(rs_cell.is_empty()); GGML_ASSERT(rs_cell.is_empty());
@ -2267,7 +2236,7 @@ struct llama_rs_cache {
int32_t n_system_seqs = 0; int32_t n_system_seqs = 0;
int32_t n_system_cells = 0; int32_t n_system_cells = 0;
for (size_t i = 0; i < seq_tails.size(); ++i) { for (size_t i = 0; i < seq_tails.size(); ++i) {
auto & seq = seq_tails[i]; const auto & seq = seq_tails[i];
if (seq.tail >= 0 && (size_t) seq.tail < size) { if (seq.tail >= 0 && (size_t) seq.tail < size) {
if (seq.shared && seq.n_cells > 0) { if (seq.shared && seq.n_cells > 0) {
n_system_seqs += 1; n_system_seqs += 1;