llama : use std::find for seq_nodes in llama_rs_cache

This commit is contained in:
Francis Couture-Harpin 2024-04-04 10:46:43 -04:00
parent 271104c65c
commit 8db1e4d45f

165
llama.cpp
View File

@ -1962,11 +1962,12 @@ struct llama_rs_seq_node {
llama_seq_id seq_id = -1;
int32_t next_cell = -1;
// needed for automatic typecasting with .find()
// needed for automatic typecasting from a llama_seq_id
llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {}
bool operator<(const llama_rs_seq_node & other) const {
return seq_id < other.seq_id;
// needed for more convenient std::find
bool operator==(const llama_rs_seq_node & other) const {
return seq_id == other.seq_id;
}
bool is_tail() const {
@ -1989,48 +1990,18 @@ struct llama_rs_cell {
// seq_ids by insertion order, to simplify updating n_cells compared to a set
std::vector<llama_rs_seq_node> seq_nodes;
llama_rs_seq_node * get_node(const llama_seq_id & id) {
for (size_t i = 0; i < seq_nodes.size(); ++i) {
if (seq_nodes[i].seq_id == id) {
return &seq_nodes[i];
}
}
return nullptr;
}
void insert_node(const llama_rs_seq_node & node) {
llama_rs_seq_node * node_dest = get_node(node.seq_id);
if (node_dest == nullptr) {
auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node);
if (node_dest == seq_nodes.end()) {
seq_nodes.push_back(node);
} else {
// overwrite the pre-existing node with the same seq_id if it exists
*node_dest = node;
}
}
bool remove_node(llama_rs_seq_node * node_ptr) {
if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) {
size_t offset = node_ptr - seq_nodes.data();
if (offset % sizeof(llama_rs_seq_node) == 0) {
offset /= sizeof(llama_rs_seq_node);
if (offset < seq_nodes.size()) {
for (size_t i = offset + 1; i < seq_nodes.size(); ++i) {
seq_nodes[i - 1] = seq_nodes[i];
}
seq_nodes.resize(seq_nodes.size() - 1);
return true;
}
}
}
return false;
}
bool has_seq_id(const llama_seq_id & id) const {
for (size_t i = 0; i < seq_nodes.size(); ++i) {
if (seq_nodes[i].seq_id == id) {
return true;
}
}
return false;
return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end();
}
bool is_empty() const {
@ -2132,67 +2103,65 @@ struct llama_rs_cache {
bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) {
if (i_cell < size && (size_t) id < size) {
llama_rs_cell & rs_cell = cells[i_cell];
auto * node_ptr = rs_cell.get_node(id); // search once
if (node_ptr != nullptr) {
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once
if (node_iter != rs_cell.seq_nodes.end()) {
if (rs_cell.seq_nodes.size() == 1) {
return clear_cell(i_cell);
} else {
// update tree
llama_rs_seq_node node = *node_ptr;
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
cells[node.next_cell].prev = rs_cell.prev;
}
if ((uint32_t) node.seq_id < seq_tails.size()) {
auto & seq = seq_tails[node.seq_id];
bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2;
if (node.is_tail()) {
seq.tail = rs_cell.prev;
if (seq.tail >= 0 && (uint32_t) seq.tail < size) {
llama_rs_cell & new_tail = cells[seq.tail];
new_tail.insert_node(node.seq_id); // ensures next_cell == -1
new_tail.tail_rc += 1;
seq.shared = cells[seq.tail].seq_nodes.size() > 1;
} else {
seq.shared = false;
}
GGML_ASSERT(rs_cell.tail_rc > 0);
rs_cell.tail_rc -= 1;
}
if (node_ptr == rs_cell.seq_nodes.data()) {
// this seq_id was the first in the list
seq.n_cells -= 1;
if (seq.n_cells == 0) {
n_seqs -= 1;
}
// the next node is the new first one, so update its n_cells
// (will never be out-of-bounds because the size is > 1)
llama_rs_seq_node next_node = node_ptr[1];
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
auto & next_seq = seq_tails[next_node.seq_id];
next_seq.n_cells += 1;
if (next_seq.n_cells == 1) {
n_seqs += 1;
}
if (other_no_longer_shared) {
next_seq.shared = false;
}
} else {
GGML_ASSERT(false && "invalid seq_id");
}
} else if (other_no_longer_shared) {
llama_rs_seq_node first_node = rs_cell.seq_nodes[0];
if ((uint32_t) first_node.seq_id < seq_tails.size()) {
seq_tails[first_node.seq_id].shared = false;
} else {
GGML_ASSERT(false && "invalid seq_id");
}
}
} else {
GGML_ASSERT(false && "invalid seq_id");
}
const bool removed = rs_cell.remove_node(node_ptr);
GGML_ASSERT(removed);
}
// else update tree
llama_rs_seq_node node = *node_iter;
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
cells[node.next_cell].prev = rs_cell.prev;
}
if ((uint32_t) node.seq_id < seq_tails.size()) {
auto & seq = seq_tails[node.seq_id];
bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2;
if (node.is_tail()) {
seq.tail = rs_cell.prev;
if (seq.tail >= 0 && (uint32_t) seq.tail < size) {
llama_rs_cell & new_tail = cells[seq.tail];
new_tail.insert_node(node.seq_id); // ensures next_cell == -1
new_tail.tail_rc += 1;
seq.shared = cells[seq.tail].seq_nodes.size() > 1;
} else {
seq.shared = false;
}
GGML_ASSERT(rs_cell.tail_rc > 0);
rs_cell.tail_rc -= 1;
}
if (node_iter == rs_cell.seq_nodes.begin()) {
// this seq_id was the first in the list
seq.n_cells -= 1;
if (seq.n_cells == 0) {
n_seqs -= 1;
}
// the next node is the new first one, so update its n_cells
// (will never be out-of-bounds because the size is > 1)
llama_rs_seq_node next_node = *(std::next(node_iter));
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
auto & next_seq = seq_tails[next_node.seq_id];
next_seq.n_cells += 1;
if (next_seq.n_cells == 1) {
n_seqs += 1;
}
if (other_no_longer_shared) {
next_seq.shared = false;
}
} else {
GGML_ASSERT(false && "invalid seq_id");
}
} else if (other_no_longer_shared) {
llama_rs_seq_node first_node = rs_cell.seq_nodes[0];
if ((uint32_t) first_node.seq_id < seq_tails.size()) {
seq_tails[first_node.seq_id].shared = false;
} else {
GGML_ASSERT(false && "invalid seq_id");
}
}
} else {
GGML_ASSERT(false && "invalid seq_id");
}
rs_cell.seq_nodes.erase(node_iter);
}
}
return false;
@ -2215,8 +2184,8 @@ struct llama_rs_cache {
if (prev >= 0 && (uint32_t) prev < size) {
// the targeted cell has a previous cell
llama_rs_cell & prev_cell = cells[prev];
llama_rs_seq_node * prev_node = prev_cell.get_node(id);
GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing
auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id);
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing
GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken
if (rs_cell.pos < 0) {
GGML_ASSERT(rs_cell.is_empty());
@ -2267,7 +2236,7 @@ struct llama_rs_cache {
int32_t n_system_seqs = 0;
int32_t n_system_cells = 0;
for (size_t i = 0; i < seq_tails.size(); ++i) {
auto & seq = seq_tails[i];
const auto & seq = seq_tails[i];
if (seq.tail >= 0 && (size_t) seq.tail < size) {
if (seq.shared && seq.n_cells > 0) {
n_system_seqs += 1;