mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 11:23:56 +01:00
llama : use std::find for seq_nodes in llama_rs_cache
This commit is contained in:
parent
271104c65c
commit
8db1e4d45f
69
llama.cpp
69
llama.cpp
@ -1962,11 +1962,12 @@ struct llama_rs_seq_node {
|
|||||||
llama_seq_id seq_id = -1;
|
llama_seq_id seq_id = -1;
|
||||||
int32_t next_cell = -1;
|
int32_t next_cell = -1;
|
||||||
|
|
||||||
// needed for automatic typecasting with .find()
|
// needed for automatic typecasting from a llama_seq_id
|
||||||
llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {}
|
llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {}
|
||||||
|
|
||||||
bool operator<(const llama_rs_seq_node & other) const {
|
// needed for more convenient std::find
|
||||||
return seq_id < other.seq_id;
|
bool operator==(const llama_rs_seq_node & other) const {
|
||||||
|
return seq_id == other.seq_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_tail() const {
|
bool is_tail() const {
|
||||||
@ -1989,48 +1990,18 @@ struct llama_rs_cell {
|
|||||||
// seq_ids by insertion order, to simplify updating n_cells compared to a set
|
// seq_ids by insertion order, to simplify updating n_cells compared to a set
|
||||||
std::vector<llama_rs_seq_node> seq_nodes;
|
std::vector<llama_rs_seq_node> seq_nodes;
|
||||||
|
|
||||||
llama_rs_seq_node * get_node(const llama_seq_id & id) {
|
|
||||||
for (size_t i = 0; i < seq_nodes.size(); ++i) {
|
|
||||||
if (seq_nodes[i].seq_id == id) {
|
|
||||||
return &seq_nodes[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void insert_node(const llama_rs_seq_node & node) {
|
void insert_node(const llama_rs_seq_node & node) {
|
||||||
llama_rs_seq_node * node_dest = get_node(node.seq_id);
|
auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node);
|
||||||
if (node_dest == nullptr) {
|
if (node_dest == seq_nodes.end()) {
|
||||||
seq_nodes.push_back(node);
|
seq_nodes.push_back(node);
|
||||||
} else {
|
} else {
|
||||||
|
// overwrite the pre-existing node with the same seq_id if it exists
|
||||||
*node_dest = node;
|
*node_dest = node;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool remove_node(llama_rs_seq_node * node_ptr) {
|
|
||||||
if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) {
|
|
||||||
size_t offset = node_ptr - seq_nodes.data();
|
|
||||||
if (offset % sizeof(llama_rs_seq_node) == 0) {
|
|
||||||
offset /= sizeof(llama_rs_seq_node);
|
|
||||||
if (offset < seq_nodes.size()) {
|
|
||||||
for (size_t i = offset + 1; i < seq_nodes.size(); ++i) {
|
|
||||||
seq_nodes[i - 1] = seq_nodes[i];
|
|
||||||
}
|
|
||||||
seq_nodes.resize(seq_nodes.size() - 1);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool has_seq_id(const llama_seq_id & id) const {
|
bool has_seq_id(const llama_seq_id & id) const {
|
||||||
for (size_t i = 0; i < seq_nodes.size(); ++i) {
|
return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end();
|
||||||
if (seq_nodes[i].seq_id == id) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_empty() const {
|
bool is_empty() const {
|
||||||
@ -2132,13 +2103,13 @@ struct llama_rs_cache {
|
|||||||
bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) {
|
bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) {
|
||||||
if (i_cell < size && (size_t) id < size) {
|
if (i_cell < size && (size_t) id < size) {
|
||||||
llama_rs_cell & rs_cell = cells[i_cell];
|
llama_rs_cell & rs_cell = cells[i_cell];
|
||||||
auto * node_ptr = rs_cell.get_node(id); // search once
|
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once
|
||||||
if (node_ptr != nullptr) {
|
if (node_iter != rs_cell.seq_nodes.end()) {
|
||||||
if (rs_cell.seq_nodes.size() == 1) {
|
if (rs_cell.seq_nodes.size() == 1) {
|
||||||
return clear_cell(i_cell);
|
return clear_cell(i_cell);
|
||||||
} else {
|
}
|
||||||
// update tree
|
// else update tree
|
||||||
llama_rs_seq_node node = *node_ptr;
|
llama_rs_seq_node node = *node_iter;
|
||||||
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
|
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
|
||||||
cells[node.next_cell].prev = rs_cell.prev;
|
cells[node.next_cell].prev = rs_cell.prev;
|
||||||
}
|
}
|
||||||
@ -2158,7 +2129,7 @@ struct llama_rs_cache {
|
|||||||
GGML_ASSERT(rs_cell.tail_rc > 0);
|
GGML_ASSERT(rs_cell.tail_rc > 0);
|
||||||
rs_cell.tail_rc -= 1;
|
rs_cell.tail_rc -= 1;
|
||||||
}
|
}
|
||||||
if (node_ptr == rs_cell.seq_nodes.data()) {
|
if (node_iter == rs_cell.seq_nodes.begin()) {
|
||||||
// this seq_id was the first in the list
|
// this seq_id was the first in the list
|
||||||
seq.n_cells -= 1;
|
seq.n_cells -= 1;
|
||||||
if (seq.n_cells == 0) {
|
if (seq.n_cells == 0) {
|
||||||
@ -2166,7 +2137,7 @@ struct llama_rs_cache {
|
|||||||
}
|
}
|
||||||
// the next node is the new first one, so update its n_cells
|
// the next node is the new first one, so update its n_cells
|
||||||
// (will never be out-of-bounds because the size is > 1)
|
// (will never be out-of-bounds because the size is > 1)
|
||||||
llama_rs_seq_node next_node = node_ptr[1];
|
llama_rs_seq_node next_node = *(std::next(node_iter));
|
||||||
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
|
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
|
||||||
auto & next_seq = seq_tails[next_node.seq_id];
|
auto & next_seq = seq_tails[next_node.seq_id];
|
||||||
next_seq.n_cells += 1;
|
next_seq.n_cells += 1;
|
||||||
@ -2190,9 +2161,7 @@ struct llama_rs_cache {
|
|||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false && "invalid seq_id");
|
GGML_ASSERT(false && "invalid seq_id");
|
||||||
}
|
}
|
||||||
const bool removed = rs_cell.remove_node(node_ptr);
|
rs_cell.seq_nodes.erase(node_iter);
|
||||||
GGML_ASSERT(removed);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -2215,8 +2184,8 @@ struct llama_rs_cache {
|
|||||||
if (prev >= 0 && (uint32_t) prev < size) {
|
if (prev >= 0 && (uint32_t) prev < size) {
|
||||||
// the targeted cell has a previous cell
|
// the targeted cell has a previous cell
|
||||||
llama_rs_cell & prev_cell = cells[prev];
|
llama_rs_cell & prev_cell = cells[prev];
|
||||||
llama_rs_seq_node * prev_node = prev_cell.get_node(id);
|
auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id);
|
||||||
GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing
|
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing
|
||||||
GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken
|
GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken
|
||||||
if (rs_cell.pos < 0) {
|
if (rs_cell.pos < 0) {
|
||||||
GGML_ASSERT(rs_cell.is_empty());
|
GGML_ASSERT(rs_cell.is_empty());
|
||||||
@ -2267,7 +2236,7 @@ struct llama_rs_cache {
|
|||||||
int32_t n_system_seqs = 0;
|
int32_t n_system_seqs = 0;
|
||||||
int32_t n_system_cells = 0;
|
int32_t n_system_cells = 0;
|
||||||
for (size_t i = 0; i < seq_tails.size(); ++i) {
|
for (size_t i = 0; i < seq_tails.size(); ++i) {
|
||||||
auto & seq = seq_tails[i];
|
const auto & seq = seq_tails[i];
|
||||||
if (seq.tail >= 0 && (size_t) seq.tail < size) {
|
if (seq.tail >= 0 && (size_t) seq.tail < size) {
|
||||||
if (seq.shared && seq.n_cells > 0) {
|
if (seq.shared && seq.n_cells > 0) {
|
||||||
n_system_seqs += 1;
|
n_system_seqs += 1;
|
||||||
|
Loading…
Reference in New Issue
Block a user