mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-04 01:57:53 +01:00
llama : use std::find for seq_nodes in llama_rs_cache
This commit is contained in:
parent
271104c65c
commit
8db1e4d45f
165
llama.cpp
165
llama.cpp
@ -1962,11 +1962,12 @@ struct llama_rs_seq_node {
|
||||
llama_seq_id seq_id = -1;
|
||||
int32_t next_cell = -1;
|
||||
|
||||
// needed for automatic typecasting with .find()
|
||||
// needed for automatic typecasting from a llama_seq_id
|
||||
llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {}
|
||||
|
||||
bool operator<(const llama_rs_seq_node & other) const {
|
||||
return seq_id < other.seq_id;
|
||||
// needed for more convenient std::find
|
||||
bool operator==(const llama_rs_seq_node & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
|
||||
bool is_tail() const {
|
||||
@ -1989,48 +1990,18 @@ struct llama_rs_cell {
|
||||
// seq_ids by insertion order, to simplify updating n_cells compared to a set
|
||||
std::vector<llama_rs_seq_node> seq_nodes;
|
||||
|
||||
llama_rs_seq_node * get_node(const llama_seq_id & id) {
|
||||
for (size_t i = 0; i < seq_nodes.size(); ++i) {
|
||||
if (seq_nodes[i].seq_id == id) {
|
||||
return &seq_nodes[i];
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void insert_node(const llama_rs_seq_node & node) {
|
||||
llama_rs_seq_node * node_dest = get_node(node.seq_id);
|
||||
if (node_dest == nullptr) {
|
||||
auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node);
|
||||
if (node_dest == seq_nodes.end()) {
|
||||
seq_nodes.push_back(node);
|
||||
} else {
|
||||
// overwrite the pre-existing node with the same seq_id if it exists
|
||||
*node_dest = node;
|
||||
}
|
||||
}
|
||||
|
||||
bool remove_node(llama_rs_seq_node * node_ptr) {
|
||||
if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) {
|
||||
size_t offset = node_ptr - seq_nodes.data();
|
||||
if (offset % sizeof(llama_rs_seq_node) == 0) {
|
||||
offset /= sizeof(llama_rs_seq_node);
|
||||
if (offset < seq_nodes.size()) {
|
||||
for (size_t i = offset + 1; i < seq_nodes.size(); ++i) {
|
||||
seq_nodes[i - 1] = seq_nodes[i];
|
||||
}
|
||||
seq_nodes.resize(seq_nodes.size() - 1);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
for (size_t i = 0; i < seq_nodes.size(); ++i) {
|
||||
if (seq_nodes[i].seq_id == id) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
@ -2132,67 +2103,65 @@ struct llama_rs_cache {
|
||||
bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) {
|
||||
if (i_cell < size && (size_t) id < size) {
|
||||
llama_rs_cell & rs_cell = cells[i_cell];
|
||||
auto * node_ptr = rs_cell.get_node(id); // search once
|
||||
if (node_ptr != nullptr) {
|
||||
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once
|
||||
if (node_iter != rs_cell.seq_nodes.end()) {
|
||||
if (rs_cell.seq_nodes.size() == 1) {
|
||||
return clear_cell(i_cell);
|
||||
} else {
|
||||
// update tree
|
||||
llama_rs_seq_node node = *node_ptr;
|
||||
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
|
||||
cells[node.next_cell].prev = rs_cell.prev;
|
||||
}
|
||||
if ((uint32_t) node.seq_id < seq_tails.size()) {
|
||||
auto & seq = seq_tails[node.seq_id];
|
||||
bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2;
|
||||
if (node.is_tail()) {
|
||||
seq.tail = rs_cell.prev;
|
||||
if (seq.tail >= 0 && (uint32_t) seq.tail < size) {
|
||||
llama_rs_cell & new_tail = cells[seq.tail];
|
||||
new_tail.insert_node(node.seq_id); // ensures next_cell == -1
|
||||
new_tail.tail_rc += 1;
|
||||
seq.shared = cells[seq.tail].seq_nodes.size() > 1;
|
||||
} else {
|
||||
seq.shared = false;
|
||||
}
|
||||
GGML_ASSERT(rs_cell.tail_rc > 0);
|
||||
rs_cell.tail_rc -= 1;
|
||||
}
|
||||
if (node_ptr == rs_cell.seq_nodes.data()) {
|
||||
// this seq_id was the first in the list
|
||||
seq.n_cells -= 1;
|
||||
if (seq.n_cells == 0) {
|
||||
n_seqs -= 1;
|
||||
}
|
||||
// the next node is the new first one, so update its n_cells
|
||||
// (will never be out-of-bounds because the size is > 1)
|
||||
llama_rs_seq_node next_node = node_ptr[1];
|
||||
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
|
||||
auto & next_seq = seq_tails[next_node.seq_id];
|
||||
next_seq.n_cells += 1;
|
||||
if (next_seq.n_cells == 1) {
|
||||
n_seqs += 1;
|
||||
}
|
||||
if (other_no_longer_shared) {
|
||||
next_seq.shared = false;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false && "invalid seq_id");
|
||||
}
|
||||
} else if (other_no_longer_shared) {
|
||||
llama_rs_seq_node first_node = rs_cell.seq_nodes[0];
|
||||
if ((uint32_t) first_node.seq_id < seq_tails.size()) {
|
||||
seq_tails[first_node.seq_id].shared = false;
|
||||
} else {
|
||||
GGML_ASSERT(false && "invalid seq_id");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false && "invalid seq_id");
|
||||
}
|
||||
const bool removed = rs_cell.remove_node(node_ptr);
|
||||
GGML_ASSERT(removed);
|
||||
}
|
||||
// else update tree
|
||||
llama_rs_seq_node node = *node_iter;
|
||||
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
|
||||
cells[node.next_cell].prev = rs_cell.prev;
|
||||
}
|
||||
if ((uint32_t) node.seq_id < seq_tails.size()) {
|
||||
auto & seq = seq_tails[node.seq_id];
|
||||
bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2;
|
||||
if (node.is_tail()) {
|
||||
seq.tail = rs_cell.prev;
|
||||
if (seq.tail >= 0 && (uint32_t) seq.tail < size) {
|
||||
llama_rs_cell & new_tail = cells[seq.tail];
|
||||
new_tail.insert_node(node.seq_id); // ensures next_cell == -1
|
||||
new_tail.tail_rc += 1;
|
||||
seq.shared = cells[seq.tail].seq_nodes.size() > 1;
|
||||
} else {
|
||||
seq.shared = false;
|
||||
}
|
||||
GGML_ASSERT(rs_cell.tail_rc > 0);
|
||||
rs_cell.tail_rc -= 1;
|
||||
}
|
||||
if (node_iter == rs_cell.seq_nodes.begin()) {
|
||||
// this seq_id was the first in the list
|
||||
seq.n_cells -= 1;
|
||||
if (seq.n_cells == 0) {
|
||||
n_seqs -= 1;
|
||||
}
|
||||
// the next node is the new first one, so update its n_cells
|
||||
// (will never be out-of-bounds because the size is > 1)
|
||||
llama_rs_seq_node next_node = *(std::next(node_iter));
|
||||
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
|
||||
auto & next_seq = seq_tails[next_node.seq_id];
|
||||
next_seq.n_cells += 1;
|
||||
if (next_seq.n_cells == 1) {
|
||||
n_seqs += 1;
|
||||
}
|
||||
if (other_no_longer_shared) {
|
||||
next_seq.shared = false;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false && "invalid seq_id");
|
||||
}
|
||||
} else if (other_no_longer_shared) {
|
||||
llama_rs_seq_node first_node = rs_cell.seq_nodes[0];
|
||||
if ((uint32_t) first_node.seq_id < seq_tails.size()) {
|
||||
seq_tails[first_node.seq_id].shared = false;
|
||||
} else {
|
||||
GGML_ASSERT(false && "invalid seq_id");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false && "invalid seq_id");
|
||||
}
|
||||
rs_cell.seq_nodes.erase(node_iter);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
@ -2215,8 +2184,8 @@ struct llama_rs_cache {
|
||||
if (prev >= 0 && (uint32_t) prev < size) {
|
||||
// the targeted cell has a previous cell
|
||||
llama_rs_cell & prev_cell = cells[prev];
|
||||
llama_rs_seq_node * prev_node = prev_cell.get_node(id);
|
||||
GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing
|
||||
auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id);
|
||||
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing
|
||||
GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken
|
||||
if (rs_cell.pos < 0) {
|
||||
GGML_ASSERT(rs_cell.is_empty());
|
||||
@ -2267,7 +2236,7 @@ struct llama_rs_cache {
|
||||
int32_t n_system_seqs = 0;
|
||||
int32_t n_system_cells = 0;
|
||||
for (size_t i = 0; i < seq_tails.size(); ++i) {
|
||||
auto & seq = seq_tails[i];
|
||||
const auto & seq = seq_tails[i];
|
||||
if (seq.tail >= 0 && (size_t) seq.tail < size) {
|
||||
if (seq.shared && seq.n_cells > 0) {
|
||||
n_system_seqs += 1;
|
||||
|
Loading…
Reference in New Issue
Block a user