llama : state checkpoints for recurrent models

Francis Couture-Harpin 2024-04-08 09:54:35 -04:00
parent 8db1e4d45f
commit 0028010d01
2 changed files with 592 additions and 269 deletions

74  ggml.c

@ -6335,19 +6335,18 @@ struct ggml_tensor * ggml_ssm_conv(
GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(ggml_is_matrix(x));
GGML_ASSERT(ggml_is_matrix(c));
GGML_ASSERT(ggml_is_matrix(sq));
GGML_ASSERT(ggml_is_vector(sq));
GGML_ASSERT(sq->type == GGML_TYPE_I32);
const int64_t d_conv = c->ne[0];
const int64_t d_inner = c->ne[1];
const int64_t n_tokens = x->ne[1];
const int64_t n_kv = s->ne[2];
const int64_t n_rs = s->ne[2];
GGML_ASSERT( s->ne[0] == d_conv - 1);
GGML_ASSERT( s->ne[1] == d_inner);
GGML_ASSERT( x->ne[0] == d_inner);
GGML_ASSERT(sq->ne[0] == n_kv);
GGML_ASSERT(sq->ne[1] == n_tokens);
GGML_ASSERT(sq->ne[0] == n_tokens);
bool is_node = false;
@ -6356,8 +6355,8 @@ struct ggml_tensor * ggml_ssm_conv(
is_node = true;
}
// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs));
result->op = GGML_OP_SSM_CONV;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -6410,7 +6409,7 @@ struct ggml_tensor * ggml_ssm_scan(
is_node = true;
}
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
result->op = GGML_OP_SSM_SCAN;
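Note: with this change, both ggml_ssm_conv and ggml_ssm_scan expect sq as a 1-D I32 tensor holding one state-slot id per token, instead of the old {n_kv, n_tokens} matrix. A minimal sketch of how a caller could build such a tensor (hypothetical helper; the context and slot ids are placeholders, the real values are filled in llama_set_inputs further down):

// sketch: build a per-token state-index tensor for ggml_ssm_conv / ggml_ssm_scan
// (hypothetical caller; assumes a ggml_context that allocates tensor data, i.e. no_alloc == false)
#include "ggml.h"

static struct ggml_tensor * make_state_indices(struct ggml_context * ctx,
                                               const int32_t * slot_per_token,
                                               int64_t n_tokens) {
    struct ggml_tensor * sq = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    int32_t * data = (int32_t *) sq->data;
    for (int64_t i = 0; i < n_tokens; ++i) {
        data[i] = slot_per_token[i];   // 0 <= slot < n_rs, exactly one slot per token
    }
    return sq;
}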
@ -15087,9 +15086,9 @@ static void ggml_compute_forward_ssm_conv_f32(
const int nc = src2->ne[0]; // d_conv
const int nr = src0->ne[1]; // d_inner
const int n_t = src1->ne[1]; // n_tokens
const int n_kv = src0->ne[2]; // max number of sequences in the batch
const int n_rs = src0->ne[2]; // max number of sequences in the batch
GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
@ -15106,10 +15105,12 @@ static void ggml_compute_forward_ssm_conv_f32(
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;
if (n_kv > 1) {
const int32_t * sq = src3->data; // {n_tokens}
if (n_rs > 1) {
// multiple sequences means it's hard to know when it's the first time a state is read,
// so copy them all over to the destination, just to be sure.
for (int i3 = 0; i3 < n_kv; ++i3) {
for (int i3 = 0; i3 < n_rs; ++i3) {
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
// can't use memcpy because of d_conv vs d_conv - 1
@ -15123,19 +15124,19 @@ static void ggml_compute_forward_ssm_conv_f32(
}
for (int i2 = 0; i2 < n_t; ++i2) {
int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
int32_t sq_i = sq[i2];
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
float * s0; // {d_conv - 1, d_inner, n_kv}
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs}
float * s0; // {d_conv - 1, d_inner, n_rs}
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
int ne0s0;
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
GGML_ASSERT(0 <= sq_i && sq_i < n_rs);
// avoid needing to copy the state for the first token
if (i2 == 0) {
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs}
ne0s0 = src0->ne[0];
} else {
// the source is the last (d_conv - 1) columns of the destination
@ -15153,18 +15154,6 @@ static void ggml_compute_forward_ssm_conv_f32(
s[(nc - 1) + i1*nc] = x0[i1];
}
// handle copies when there are multiple output states
for (int i3 = 1; i3 < n_kv; ++i3) {
int32_t seq = sq[i3];
if (0 <= seq && seq < n_kv) {
float * s1 = s + (seq - sq[0])*nc*nr;
memcpy(s1, s, nc*ir*sizeof(float));
} else {
// stop at negative or too big seq_ids
break;
}
}
// it seems a little faster when this is separate from the state shift
for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product
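Note: since sq now names exactly one state slot per token, each token updates a single conv state and the "copy to the other output states" loop removed above is no longer needed. A standalone toy of the per-token shift-and-append on a chosen slot (buffer names and layout are illustrative only):

// sketch: per-token selection and update of a conv state slot
#include <cstdint>
#include <vector>

void conv_state_step(std::vector<float> & states,   // [n_rs][d_inner][d_conv], d_conv contiguous
                     const float * x_token,          // [d_inner]
                     int32_t slot, int d_conv, int d_inner) {
    float * s = states.data() + (size_t) slot * d_inner * d_conv;
    for (int i1 = 0; i1 < d_inner; ++i1) {
        float * row = s + (size_t) i1 * d_conv;
        // shift the window left by one and append the new column,
        // like the d_conv - 1 shift done in ggml_compute_forward_ssm_conv_f32
        for (int i0 = 0; i0 < d_conv - 1; ++i0) {
            row[i0] = row[i0 + 1];
        }
        row[d_conv - 1] = x_token[i1];
    }
}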
@ -15216,7 +15205,7 @@ static void ggml_compute_forward_ssm_scan_f32(
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
@ -15225,6 +15214,7 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
// required for the dot product between s and C, and when copying the states
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
@ -15240,10 +15230,12 @@ static void ggml_compute_forward_ssm_scan_f32(
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;
if (n_kv > 1) {
const int32_t * sq = src6->data; // {n_tokens}
if (n_rs > 1) {
// with multiple sequences, it's hard to know if the source states have already been copied,
// so copy them all over to the destination, just to be sure.
for (int i3 = 0; i3 < n_kv; ++i3) {
for (int i3 = 0; i3 < n_rs; ++i3) {
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
memcpy(s, s0, nc*ir*sizeof(float));
@ -15251,9 +15243,9 @@ static void ggml_compute_forward_ssm_scan_f32(
}
for (int i2 = 0; i2 < n_t; ++i2) {
int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
int32_t sq_i = sq[i2];
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs}
float * s0;
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
@ -15261,11 +15253,11 @@ static void ggml_compute_forward_ssm_scan_f32(
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
GGML_ASSERT(0 <= sq_i && sq_i < n_rs);
// avoid needing to copy the state for the first token
if (i2 == 0) {
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs}
} else {
// otherwise the source is the same as the destination
s0 = s;
@ -15288,18 +15280,6 @@ static void ggml_compute_forward_ssm_scan_f32(
}
y[i1] = sumf;
}
// handle copies when there are multiple output states
for (int i3 = 1; i3 < n_kv; ++i3) {
int32_t seq = sq[i3];
if (0 <= seq && seq < n_kv) {
float * s1 = s + (seq - sq[0])*nc*nr;
memcpy(s1, s, nc*ir*sizeof(float));
} else {
// stop at negative or too big seq_ids
break;
}
}
}
}
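Note: both kernels write into a 2-in-1 dst tensor: the per-token output first, then the updated states, with the slot offset picked by sq_i. A tiny standalone sketch of the byte offsets (sizes are made up, not from the commit):

// sketch: offsets into the 2-in-1 dst of ggml_ssm_scan
// dst holds {d_inner, n_tokens} of y followed by {d_state, d_inner, n_rs} of states
#include <cstddef>
#include <cstdio>

int main() {
    const size_t d_state = 16, d_inner = 1536, n_tokens = 4, n_rs = 2;  // illustrative sizes
    const size_t y_bytes     = d_inner * n_tokens * sizeof(float);      // == src1->nb[2]
    const size_t state_bytes = d_state * d_inner * n_rs * sizeof(float);
    const size_t sq_i        = 1;                                       // slot chosen by sq for this token
    const size_t slot_off    = y_bytes + sq_i * d_state * d_inner * sizeof(float);
    printf("y: %zu bytes, states: %zu bytes, slot %zu starts at byte %zu\n",
           y_bytes, state_bytes, sq_i, slot_off);
    return 0;
}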

647  llama.cpp

@ -2016,11 +2016,13 @@ struct llama_rs_seq_meta {
// number of cells for which this seq_id is the first
// (useful to know if cells in this sequence should be pruned)
int32_t n_cells = 0;
// whether the tail is a cell part of multiple sequences
bool shared = false;
// changing the tail cell of a sequence can only be done at batch boundary,
// this guards against changing the cell when it shouldn't be;
// should be cleared when done finding a slot
bool in_ubatch = false;
};
// ring-buffer of cached recurrent state data
// ring-buffered tree of cached recurrent state data
struct llama_rs_cache {
bool do_copy = false;
@ -2032,8 +2034,10 @@ struct llama_rs_cache {
uint32_t n = 0; // range of states used for the last slot
// useful to know the minimum reserved cell count per seq_id
// only counts sequences with n_cells > 0
// only counts sequences with n_cells > 0 AND which have a non-shared tail
uint32_t n_seqs = 0;
// cells part of multiple sequences AND which have at least one tail
uint32_t n_shared_tail_cells = 0;
// with state models, a cell can hold the state for more than one past token
// TODO: it's probably not possible to always use contiguous cells
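Note: a trimmed-down, hypothetical mirror of the bookkeeping introduced here, only to make the field roles explicit (these are not the real structs):

// sketch: the per-sequence metadata and cache counters this commit relies on
#include <cstdint>
#include <vector>

struct rs_seq_meta_sketch {
    int32_t tail      = -1;     // cell id holding the latest state of this seq_id
    int32_t n_cells   = 0;      // cells for which this seq_id is the first node
    bool    in_ubatch = false;  // set while finding a slot, cleared when done
};

struct rs_cache_sketch {
    uint32_t n_seqs              = 0;  // seqs with n_cells > 0 and a non-shared tail
    uint32_t n_shared_tail_cells = 0;  // cells shared by several seqs with at least one tail
    std::vector<rs_seq_meta_sketch> seq_tails;
};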
@ -2047,10 +2051,204 @@ struct llama_rs_cache {
std::vector<struct ggml_tensor *> r_l; // rolling/shift states
std::vector<struct ggml_tensor *> s_l; // ssm (recurrent) states
// TODO: maybe use a simpler data structure than a tree
// Inefficient, but thorough verification and rebuilding of the rs cache
// from only the cells list with `pos` and seq_ids.
// Should not be called in a hot loop except when desperate and/or debugging.
bool rebuild(bool debug) {
bool was_valid = true;
// the source of truth is the cells list
// buffer sizes
if (size != cells.size()) {
if (debug) {
LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n",
__func__, cells.size(), size);
}
cells.resize(size);
was_valid = false;
}
if (size != seq_tails.size()) {
if (debug) {
LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n",
__func__, seq_tails.size(), size);
}
seq_tails.resize(size);
was_valid = false;
}
// cells consistency
uint32_t used_verif = 0;
for (uint32_t cell_id = 0; cell_id < size; ++cell_id) {
llama_rs_cell & cell = cells[cell_id];
if (cell.seq_nodes.empty()) {
if (cell.pos >= 0) {
cell.pos = -1;
was_valid = false;
}
}
if (cell.pos < 0) {
if (cell.pos != -1) {
cell.pos = -1;
was_valid = false;
}
if (!cell.seq_nodes.empty()) {
cell.seq_nodes.clear();
was_valid = false;
}
cell.src = -1;
if (cell.prev != -1) {
cell.prev = -1;
was_valid = false;
}
} else if (!debug) {
// Assuming the cache should actually be rebuilt when not debugging
cell.src = cell_id;
}
if (!cell.seq_nodes.empty()) {
used_verif += 1;
}
}
if (used != used_verif) {
if (debug) {
LLAMA_LOG_ERROR("%s: invalid used cell count (%u instead of %u)\n",
__func__, used, used_verif);
}
used = used_verif;
was_valid = false;
}
// tail verification
std::vector<std::pair<llama_pos, uint32_t>> seq_cells;
for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) {
auto & seq = seq_tails[seq_id];
seq_cells.clear();
for (uint32_t cell_id = 0; cell_id < size; ++cell_id) {
llama_rs_cell & cell = cells[cell_id];
if (cell.has_seq_id(seq_id)) {
seq_cells.push_back({cell.pos, cell_id});
}
}
// sort by pos and then by cell_id
std::sort(seq_cells.begin(), seq_cells.end());
int32_t tail = seq_cells.empty() ? -1 : seq_cells[seq_cells.size() - 1].second;
if (tail != seq.tail) {
if (debug) {
LLAMA_LOG_ERROR("%s: invalid tail for seq_id %d (%d instead of %d)\n",
__func__, seq_id, seq.tail, tail);
}
seq.tail = tail;
was_valid = false;
}
int32_t prev = -1;
for (size_t i = 0; i < seq_cells.size(); ++i) {
uint32_t cell_id = seq_cells[i].second;
llama_rs_cell & cell = cells[cell_id];
if (cell.prev != prev) {
// TODO: relax the error when multiple cells have the same pos
if (debug) {
LLAMA_LOG_ERROR("%s: invalid prev cell for cells[%u] (%d instead of %d)\n",
__func__, cell_id, cell.prev, prev);
}
cell.prev = prev;
was_valid = false;
}
prev = cell_id;
}
int32_t n_cells = 0;
int32_t next = -1;
for (size_t i = seq_cells.size(); i-- > 0;) {
uint32_t cell_id = seq_cells[i].second;
llama_rs_cell & cell = cells[cell_id];
// assuming it's always found, because how else would it end up in the list of cells for this seq_id?
auto seq_node = std::find(cell.seq_nodes.begin(), cell.seq_nodes.end(), seq_id);
if (seq_node == cell.seq_nodes.begin()) {
n_cells += 1;
}
if (seq_node->next_cell != next) {
// TODO: relax the error when multiple cells have the same pos
if (debug) {
LLAMA_LOG_ERROR("%s: invalid next cell for cells[%u] (%d instead of %d)\n",
__func__, cell_id, seq_node->next_cell, next);
}
seq_node->next_cell = next;
was_valid = false;
}
next = cell_id;
}
if (seq.n_cells != n_cells) {
if (debug) {
LLAMA_LOG_ERROR("%s: invalid n_cells for seq_id %d (%d instead of %d)\n",
__func__, seq_id, seq.n_cells, n_cells);
}
seq.n_cells = n_cells;
}
// in_ubatch should only be true when in the process of finding a slot
if (seq.in_ubatch != false) {
if (debug) {
LLAMA_LOG_ERROR("%s: in_ubatch was true while it should have been false for seq_id %d\n",
__func__, seq_id);
}
seq.in_ubatch = false;
was_valid = false;
}
}
// tail_rc
for (uint32_t cell_id = 0; cell_id < size; ++cell_id) {
llama_rs_cell & cell = cells[cell_id];
uint32_t tail_rc = 0;
for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) {
auto & seq = seq_tails[seq_id];
if (seq.tail >= 0 && (uint32_t) seq.tail == cell_id) {
tail_rc += 1;
}
}
if (cell.tail_rc != tail_rc) {
if (debug) {
LLAMA_LOG_ERROR("%s: invalid tail_rc for cells[%u] (%u instead of %u)\n",
__func__, cell_id, cell.tail_rc, tail_rc);
}
cell.tail_rc = tail_rc;
was_valid = false;
}
}
// n_seqs
uint32_t n_seqs_verif = 0;
uint32_t n_shared_tail_cells_verif = 0;
for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) {
auto & seq = seq_tails[seq_id];
if (seq.tail >= 0) {
llama_rs_cell & tail_cell = cells[seq.tail];
// NOTE: could also have checked if n_cells > 0
if (!tail_cell.seq_nodes.empty() && tail_cell.seq_nodes[0].seq_id == seq_id) {
if (tail_cell.seq_nodes.size() > 1) {
n_shared_tail_cells_verif += 1;
} else {
n_seqs_verif += 1;
}
}
}
}
if (n_seqs != n_seqs_verif) {
if (debug) {
LLAMA_LOG_ERROR("%s: wrong n_seqs (%u instead of %u)\n",
__func__, n_seqs, n_seqs_verif);
}
n_seqs = n_seqs_verif;
was_valid = false;
}
if (n_shared_tail_cells != n_shared_tail_cells_verif) {
if (debug) {
LLAMA_LOG_ERROR("%s: wrong n_shared_tail_cells (%u instead of %u)\n",
__func__, n_shared_tail_cells, n_shared_tail_cells_verif);
}
n_shared_tail_cells = n_shared_tail_cells_verif;
was_valid = false;
}
return was_valid;
}
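Note: the tail verification above boils down to "collect (pos, cell_id) pairs for a seq_id, sort, and the last entry is the tail". A standalone toy of that recovery step, with made-up data; per the comment at the top of rebuild(), this is meant for debugging or last-resort repair, not hot paths:

// sketch: recovering a sequence's tail the way rebuild() does
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    // (pos, cell_id) pairs of one seq_id, in arbitrary cell order
    std::vector<std::pair<int32_t, uint32_t>> seq_cells = {{7, 2}, {3, 5}, {23, 0}};
    std::sort(seq_cells.begin(), seq_cells.end());   // by pos, then by cell_id
    int32_t tail = seq_cells.empty() ? -1 : (int32_t) seq_cells.back().second;
    printf("tail cell = %d\n", tail);                // prints 0 (pos 23 is the latest)
    return 0;
}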
// returns whether or not a cell was freed
bool clear_cell(uint32_t i) {
if (i < size) {
llama_rs_cell & rs_cell = cells[i];
void clear_cell(llama_rs_cell & rs_cell) {
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
if (!rs_cell.is_empty()) {
// update sequence tree links
bool first = true;
@ -2059,27 +2257,34 @@ struct llama_rs_cache {
// NOTE: if all next cells are the same cell, this should still work
cells[node.next_cell].prev = rs_cell.prev;
}
// next_cell of the nodes of the previous cell
if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) {
llama_rs_cell & prev_cell = cells[rs_cell.prev];
auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node);
// assuming the previous node is always found
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
prev_node->next_cell = node.next_cell;
if (node.is_tail()) {
prev_cell.tail_rc += 1;
}
}
if ((uint32_t) node.seq_id < seq_tails.size()) {
auto & seq = seq_tails[node.seq_id];
// update tail
if (node.is_tail()) {
seq.tail = rs_cell.prev;
if (seq.tail >= 0 && (uint32_t) seq.tail < size) {
llama_rs_cell & new_tail = cells[seq.tail];
new_tail.insert_node(node.seq_id); // ensures next_cell == -1
new_tail.tail_rc += 1;
seq.shared = new_tail.seq_nodes.size() > 1;
} else {
seq.shared = false;
}
}
// cell counts
if (first) {
seq.n_cells -= 1;
if (seq.n_cells == 0) {
GGML_ASSERT(seq.tail < 0);
if (rs_cell.tail_rc > 0 && seq.tail < 0) {
// last tail cell
if (rs_cell.seq_nodes.size() > 1) {
n_shared_tail_cells -= 1;
} else {
n_seqs -= 1;
}
}
first = false;
}
} else {
@ -2092,39 +2297,42 @@ struct llama_rs_cache {
rs_cell.tail_rc = 0;
rs_cell.seq_nodes.clear();
used -= 1;
return true;
}
}
return false;
}
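Note: the prev/next updates in clear_cell above and remove_seq_node_from_cell below are the usual splice of a linked list stored over the cell array. A toy version with plain prev/next arrays (the real code keeps next_cell per (cell, seq_id) node and prev per cell):

// sketch: splicing a cell out of a per-sequence chain stored over the cell array
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // chain for one seq_id: 3 -> 1 -> 4 (tail), encoded as prev/next per cell id
    std::vector<int32_t> prev = {-1, 3, -1, -1, 1};
    std::vector<int32_t> next = {-1, 4, -1, 1, -1};
    int32_t cleared = 1;                                        // clear the middle cell
    if (next[cleared] >= 0) prev[next[cleared]] = prev[cleared];
    if (prev[cleared] >= 0) next[prev[cleared]] = next[cleared];
    prev[cleared] = next[cleared] = -1;
    printf("3 -> %d (prev of 4 is now %d)\n", next[3], prev[4]);  // 3 -> 4, prev of 4 is 3
    return 0;
}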
// TODO: maybe use a simpler data structure than a tree
// returns whether or not a cell was freed
bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) {
if (i_cell < size && (size_t) id < size) {
llama_rs_cell & rs_cell = cells[i_cell];
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once
// returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed.
std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) {
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
// TODO: assert the iterator points inside the correct vector
if (node_iter != rs_cell.seq_nodes.end()) {
if (rs_cell.seq_nodes.size() == 1) {
return clear_cell(i_cell);
clear_cell(rs_cell);
return rs_cell.seq_nodes.end();
}
// else update tree
llama_rs_seq_node node = *node_iter;
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
cells[node.next_cell].prev = rs_cell.prev;
}
if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) {
llama_rs_cell & prev_cell = cells[rs_cell.prev];
auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node);
// assuming the previous node is always found
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
prev_node->next_cell = node.next_cell;
if (node.is_tail()) {
prev_cell.tail_rc += 1;
}
}
if ((uint32_t) node.seq_id < seq_tails.size()) {
auto & seq = seq_tails[node.seq_id];
bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2;
if (node.is_tail()) {
seq.tail = rs_cell.prev;
if (seq.tail >= 0 && (uint32_t) seq.tail < size) {
llama_rs_cell & new_tail = cells[seq.tail];
new_tail.insert_node(node.seq_id); // ensures next_cell == -1
new_tail.tail_rc += 1;
seq.shared = cells[seq.tail].seq_nodes.size() > 1;
} else {
seq.shared = false;
if (seq.tail < 0 && rs_cell.tail_rc == 1) {
// assuming the previous cell of a shared cell is also shared,
// (no need to update the shared tail cells count elsewhere, then)
// this was a shared tail cell, but will no longer be a tail cell
n_shared_tail_cells -= 1;
}
GGML_ASSERT(rs_cell.tail_rc > 0);
rs_cell.tail_rc -= 1;
@ -2132,42 +2340,43 @@ struct llama_rs_cache {
if (node_iter == rs_cell.seq_nodes.begin()) {
// this seq_id was the first in the list
seq.n_cells -= 1;
if (seq.n_cells == 0) {
n_seqs -= 1;
}
// the next node is the new first one, so update its n_cells
// (will never be out-of-bounds because the size is > 1)
llama_rs_seq_node next_node = *(std::next(node_iter));
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
auto & next_seq = seq_tails[next_node.seq_id];
next_seq.n_cells += 1;
if (next_seq.n_cells == 1) {
// only the tail ref count from the other seq_ids is left in tail_rc
if (rs_cell.tail_rc > 0) {
// will become a non-shared cell
if (rs_cell.seq_nodes.size() == 2) {
n_seqs += 1;
}
if (other_no_longer_shared) {
next_seq.shared = false;
}
} else {
GGML_ASSERT(false && "invalid seq_id");
}
} else if (other_no_longer_shared) {
llama_rs_seq_node first_node = rs_cell.seq_nodes[0];
if ((uint32_t) first_node.seq_id < seq_tails.size()) {
seq_tails[first_node.seq_id].shared = false;
} else {
GGML_ASSERT(false && "invalid seq_id");
}
}
} else {
GGML_ASSERT(false && "invalid seq_id");
}
rs_cell.seq_nodes.erase(node_iter);
return rs_cell.seq_nodes.erase(node_iter);
}
return node_iter;
}
// returns whether or not the seq_id was removed
bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) {
if (i_cell < size && (size_t) id < size) {
llama_rs_cell & rs_cell = cells[i_cell];
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once
return node_iter != remove_seq_node_from_cell(rs_cell, node_iter);
}
return false;
}
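Note: remove_seq_node_from_cell now returns the iterator following the removed node, which is what makes the erase-while-iterating loops later in this commit (llama_cache_seq_keep, the divergence pruning in find_slot) safe. The same idiom over a plain vector, as a toy:

// sketch: erase-while-iterating using the "returns the next iterator" contract
#include <cstdio>
#include <iterator>
#include <vector>

int main() {
    std::vector<int> seq_ids = {0, 2, 3, 2};
    const int keep = 2;
    for (auto it = seq_ids.begin(); it != seq_ids.end();) {
        if (*it != keep) {
            it = seq_ids.erase(it);   // erase() returns the iterator after the removed element
        } else {
            it = std::next(it);
        }
    }
    for (int id : seq_ids) printf("%d ", id);   // prints: 2 2
    printf("\n");
    return 0;
}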
bool insert_seq_tail_to_cell(uint32_t i_cell, const llama_seq_id & id) {
bool insert_seq_tail_to_cell_id(uint32_t i_cell, const llama_seq_id & id) {
if (i_cell < size && (size_t) id < seq_tails.size()) {
llama_rs_cell & rs_cell = cells[i_cell];
auto & seq = seq_tails[id];
@ -2194,10 +2403,11 @@ struct llama_rs_cache {
}
prev_cell.tail_rc -= 1;
prev_node->next_cell = i_cell;
rs_cell.prev = prev;
}
if (rs_cell.is_empty()) {
// only add after potential failures above
if (seq.n_cells == 0) {
// either the sequence didn't own any cells or had a shared tail cell
if (seq.n_cells == 0 || (seq.tail >= 0 && cells[seq.tail].seq_nodes.size() > 1)) {
n_seqs += 1;
}
seq.n_cells += 1;
@ -2206,12 +2416,40 @@ struct llama_rs_cache {
rs_cell.pos = 0;
rs_cell.src = -1;
}
used += 1;
} else if (rs_cell.seq_nodes.size() == 1 && rs_cell.tail_rc == 1) {
// don't count shared-cell tails
// FIXME: make this saner
n_seqs -= 1;
n_shared_tail_cells += 1;
} else if (rs_cell.tail_rc == 0) {
// shared cell without a tail gets a tail;
// FIXME: don't prune, in case this is used in llama_cache_seq_cp
GGML_ASSERT(false); // make sure we don't get here by accident
// prune the other sequences out of this cell
// NOTE: have to inline the removal because the state tree is partially invalid
bool first = true;
for (auto & node : rs_cell.seq_nodes) {
GGML_ASSERT(node.seq_id != id);
GGML_ASSERT(node.next_cell >= 0);
// easy removal, none of the nodes are tails
llama_rs_cell & next_cell = cells[node.next_cell];
next_cell.prev = rs_cell.prev;
if (first) {
auto & first_seq = seq_tails[node.seq_id];
first_seq.n_cells -= 1;
first = false;
}
}
rs_cell.seq_nodes.clear();
} else if (rs_cell.seq_nodes.size() != rs_cell.tail_rc) {
// this is correct as long as this isn't called when trying to find a slot
// TODO: find a way to assert this
}
// the target cell was not already a tail of this seq_id
rs_cell.insert_node(id); // next_cell == -1 by default
rs_cell.tail_rc += 1;
seq.tail = i_cell;
seq.shared = rs_cell.seq_nodes.size() > 1;
return true;
}
return false;
@ -2219,33 +2457,12 @@ struct llama_rs_cache {
// each seq_id should have access to at least this many cells
// (to use when pruning (to avoid over-pruning))
// (but this over-prunes when the system prompt doesn't take lots of cells)
// Hmm. The system prompt does not need checkpoints...
size_t min_cells_per_seq() const {
return size / (n_seqs > 0 ? n_seqs : 1);
size_t min_cells_per_seq(const llama_rs_seq_meta & new_seq) const {
uint32_t seqs = n_seqs;
if (new_seq.tail < 0 || new_seq.n_cells == 0) {
seqs += 1;
}
// each seq_id can have at most this many cells
// (ignoring seqs which behave as a shared prompt)
// TODO: avoid recalculating system seq_ids
// (to use when pruning (to avoid over-pruning))
// NOTE: this also limits the shared prompt to at most half the cells
// (but the shared prompt technically needs only one cell...)
// (IDEA: keep only one cell when `llama_kv_cache_seq_cp` is called on a sequence)
size_t max_cells_per_seq() const {
int32_t n_system_seqs = 0;
int32_t n_system_cells = 0;
for (size_t i = 0; i < seq_tails.size(); ++i) {
const auto & seq = seq_tails[i];
if (seq.tail >= 0 && (size_t) seq.tail < size) {
if (seq.shared && seq.n_cells > 0) {
n_system_seqs += 1;
n_system_cells += seq.n_cells;
}
}
}
int32_t n_other_seqs = n_seqs - n_system_seqs;
return (size - n_system_cells) / (n_other_seqs > 0 ? n_other_seqs : 1);
return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1);
}
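Note: a worked example of the new bound, with made-up counts: 8 cells, 2 non-shared sequences, 1 shared tail cell, and a new sequence without a tail gives (8 - 1) / 3 = 2 reserved cells per sequence.

// sketch: worked example of the new min_cells_per_seq()
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t size = 8, n_seqs = 2, n_shared_tail_cells = 1;   // illustrative counts
    const bool new_seq_has_no_tail = true;          // new_seq.tail < 0 || new_seq.n_cells == 0
    uint32_t seqs = n_seqs + (new_seq_has_no_tail ? 1 : 0);
    size_t min_cells = (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1);
    printf("each seq_id should keep at least %zu cells\n", min_cells);  // (8 - 1) / 3 = 2
    return 0;
}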
size_t total_size() const {
@ -2528,7 +2745,7 @@ struct llama_context {
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [n_rs]
struct ggml_tensor * inp_s_mask; // F32 [1, n_rs]
struct ggml_tensor * inp_s_seq; // I32 [n_rs, n_batch]
struct ggml_tensor * inp_s_seq; // I32 [n_batch]
// control vectors
struct llama_control_vector cvec;
@ -2657,7 +2874,7 @@ static bool llama_cache_init(
return false;
}
ggml_backend_buffer_clear(buf, 0);
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
LLAMA_LOG_INFO("%s: %10s ctx buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
cache.bufs.push_back(buf);
}
@ -2678,54 +2895,170 @@ static bool llama_kv_cache_find_slot(
if (rs_size > 0) {
// For recurrent state architectures (like Mamba),
// each cache cell can store the state for a whole sequence.
// TODO: real ring-buffer of states
// TODO: state checkpoints (multiple cells per sequence)
// TODO: find a way to always make the rs slot contiguous
// Okay, need to find a slot. Everything should fit assuming the biggest seq_id < rs_size
llama_seq_id min = cache.rs.size - 1;
llama_seq_id max = 0;
llama_seq_id min_seq = cache.rs.size - 1;
llama_seq_id max_seq = 0;
uint32_t min_cell = cache.rs.size - 1;
uint32_t max_cell = 0;
for (uint32_t i = 0; i < n_tokens; ++i) {
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
int32_t target_cell = -1; // ensure all the sequences of a token get the same cell
int32_t n_seq_ids = batch.n_seq_id[i];
for (int32_t j = 0; j < n_seq_ids; ++j) {
llama_seq_id seq_id = batch.seq_id[i][j];
// make sure it's a valid seq_id
bool need_new_cell = false;
// Everything should fit assuming the biggest seq_id < rs_size
if ((uint32_t) seq_id < rs_size) {
if (seq_id > max) {
max = seq_id;
llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id];
if (seq_id > max_seq) { max_seq = seq_id; }
if (seq_id < min_seq) { min_seq = seq_id; }
if (!seq.in_ubatch && target_cell >= 0) {
// never saw this seq_id before,
// but there's already a cell reserved for this token, use it
cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id);
} else if (seq.tail < 0) {
need_new_cell = true;
} else {
llama_rs_cell & tail = cache.rs.cells[seq.tail];
if (seq.in_ubatch) {
// this seq_id was already seen before in the batch
// assuming the tail cell already "has" this seq_id
tail.pos += 1;
target_cell = seq.tail;
} else {
// first time this sequence is seen,
// there's no reserved cell yet;
// if it's not the first sequence of the token, how could it even get here?
GGML_ASSERT(j == 0);
bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids;
if (has_same_seqs) {
// the tail cell of a seq_id is assumed to already be part of the seq_id,
// hence the skip of the first seq_id
for (int32_t k = 1; k < n_seq_ids; ++k) {
if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) {
has_same_seqs = false;
}
if (seq_id < min) {
min = seq_id;
}
}
// TODO: make the checkpoint interval configurable
if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) {
// a checkpoint should be saved
need_new_cell = true;
} else {
// re-use last tail
tail.pos += 1;
target_cell = seq.tail;
}
}
}
if (need_new_cell && target_cell < 0) {
const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq);
uint32_t cell_id = cache.rs.size;
bool looped_once = false;
while (true) {
if (cache.rs.head >= cache.rs.size) {
cache.rs.head = 0;
if (looped_once) {
// avoid infinite loop
// NOTE: this should not happen, but gracefully fail anyway
LLAMA_LOG_ERROR("%s: recurrent state cache seems full, but should not. This is a bug.\n", __func__);
return false;
}
looped_once = true;
}
cell_id = cache.rs.head;
llama_rs_cell & candidate = cache.rs.cells[cell_id];
if (candidate.is_empty()) { break; }
if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) {
if (candidate.seq_nodes.size() > 1) {
// prune out the other seq_ids, because they diverge
// TODO(maybe): handle this in insert_seq_tail_to_cell_id
// (hopefully doesn't happen too often)
for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) {
if (node_iter->seq_id == seq_id) {
node_iter = std::next(node_iter);
} else {
node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter);
}
}
}
// re-use the tail cell to avoid not finding anything
candidate.pos += 1;
break;
}
if (candidate.tail_rc > 0) {
// skip tails of other sequences
cache.rs.head += 1;
continue;
}
if (candidate.seq_nodes.size() > 1) {
// shared prompts are not usually backtracked, so they can be pruned
cache.rs.clear_cell(candidate);
break;
}
// prune too-long sequences
llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id;
if (seq_id_to_prune == seq_id) {
// TODO: selectively skip some cells to keep older states
cache.rs.clear_cell(candidate);
break;
}
GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size());
auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune];
if (seq_to_prune.n_cells > min_cells_per_seq) {
cache.rs.clear_cell(candidate);
break;
}
cache.rs.head += 1;
}
if (cell_id < cache.rs.size) {
cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id);
target_cell = cell_id;
}
}
if (seq.tail >= 0) {
if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; }
if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; }
seq.in_ubatch = true;
}
// Assuming the tokens are in-order
if (batch.pos[i] != cache.rs.cells[seq_id].pos + 1) {
if (batch.pos[i] != cache.rs.cells[seq.tail].pos) {
// What should happen when the pos backtracks or skips a value?
// Clearing the state mid-batch would require special-casing which isn't done.
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
__func__, batch.pos[i], cache.rs.cells[seq_id].pos, seq_id);
__func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id);
}
if (cache.rs.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
cache.rs.used += 1;
}
cache.rs.cells[seq_id].pos = batch.pos[i];
cache.rs.cells[seq_id].seq_nodes.insert(seq_id);
} else {
// too big seq_id
// TODO: would it be possible to resize the KV cache size instead?
// TODO: would it be possible to resize the rs cache size instead?
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size);
return false;
}
}
cache.rs.head = target_cell + 1;
}
for (llama_seq_id i = min_seq; i <= max_seq; ++i) {
// make sure it's cleared for next time
cache.rs.seq_tails[i].in_ubatch = false;
}
// allow getting the range of used cells, from head to head + n
cache.rs.head = min;
cache.rs.n = max - min + 1;
cache.rs.head = min_cell;
cache.rs.n = max_cell - min_cell + 1;
// sanity check
if (max < min) {
if (max_seq < min_seq || max_cell < min_cell) {
return false;
}
}
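Note: the slot search above freezes the current tail as a checkpoint and starts a new cell roughly every 16 tokens per sequence (and also when the token's seq_id set changes). A simplified toy trace of just the interval condition, ignoring the has_same_seqs and cell-availability parts; the 16-token interval is hard-coded in the commit and flagged as TODO-configurable:

// sketch: where new checkpoint cells start for a single sequence fed tokens 0..47
#include <cstdio>
#include <vector>

int main() {
    const int interval = 16;
    int prev_pos = -1;                       // pos of the tail's prev cell (last checkpoint)
    int tail_pos = -1;                       // pos of the current tail cell
    std::vector<int> cell_starts;
    for (int pos = 0; pos < 48; ++pos) {
        const bool need_new_cell = tail_pos < 0 || prev_pos < 0 || tail_pos - prev_pos >= interval;
        if (need_new_cell) {
            prev_pos = tail_pos;             // the old tail is frozen as a checkpoint
            cell_starts.push_back(pos);      // a fresh tail cell starts at this pos
        }
        tail_pos = pos;                      // the tail always advances to the current pos
    }
    for (int p : cell_starts) printf("%d ", p);  // prints: 0 1 17 33
    printf("\n");                                // (the second token gets its own cell because
    return 0;                                    //  the first tail has no prev yet)
}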
@ -2799,6 +3132,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
return 0;
}
// find how many recurrent state cells are currently in use
static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) {
for (uint32_t i = cache.size; i > 0; --i) {
const llama_rs_cell & cell = cache.cells[i - 1];
@ -2829,12 +3163,15 @@ static void llama_cache_clear(struct llama_cache & cache) {
llama_rs_cell & rs_cell = cache.rs.cells[i];
rs_cell.pos = -1;
rs_cell.src = -1;
rs_cell.prev = -1;
rs_cell.tail_rc = 0;
rs_cell.seq_nodes.clear();
}
cache.rs.do_copy = false;
cache.rs.head = 0;
cache.rs.used = 0;
cache.rs.n_seqs = 0;
cache.rs.n_shared_tail_cells = 0;
cache.rs.seq_tails.clear();
cache.rs.seq_tails.resize(cache.rs.size);
}
@ -2846,8 +3183,8 @@ static llama_pos llama_cache_seq_rm(
llama_pos p0,
llama_pos p1) {
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
if (p0 < 0) { p0 = 0; }
if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }
llama_pos n_past = p0;
@ -2863,7 +3200,9 @@ static llama_pos llama_cache_seq_rm(
for (uint32_t i = 0; i < cache.rs.size; ++i) {
llama_rs_cell & rs_cell = cache.rs.cells[i];
if (seq_id < 0 || rs_cell.has_seq_id(seq_id)) {
auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id);
if (seq_id < 0 || seq_node != rs_cell.seq_nodes.end()) {
if (rs_cell.pos < p0) {
// move forward the new p0 further
if (rs_cell.pos >= new_p0) {
@ -2879,9 +3218,9 @@ static llama_pos llama_cache_seq_rm(
}
} else { // (rs_cell.pos >= p0 && rs_cell.pos < p1)
if (seq_id < 0) {
cache.rs.clear_cell(i);
cache.rs.clear_cell(rs_cell);
} else { // (rs_cell.has_seq_id(seq_id))
cache.rs.remove_seq_from_cell(i, seq_id);
cache.rs.remove_seq_node_from_cell(rs_cell, seq_node);
}
if (rs_cell.is_empty() && new_head == cache.rs.size) {
new_head = i;
@ -2943,11 +3282,12 @@ static llama_pos llama_cache_seq_cp(
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
if (p0 < 0) { p0 = 0; }
if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }
// TODO: in practice this seems to be only used on whole sequences;
// should partial sequence copy be removed?
// should partial sequence copy support be removed?
// TODO: What if the destination sequence is not empty?
llama_pos n_past = 0;
@ -2973,11 +3313,11 @@ static llama_pos llama_cache_seq_cp(
if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) {
if (i == (uint32_t) src_tail) {
// need to be inserted in order, but there's only one
cache.rs.insert_seq_tail_to_cell(i, seq_id_dst);
cache.rs.insert_seq_tail_to_cell_id(i, seq_id_dst);
} else {
// keep only the tail cell of the source
// assuming a copy means no rollback will be attempted afterwards
cache.rs.remove_seq_from_cell(i, seq_id_src);
cache.rs.remove_seq_from_cell_id(i, seq_id_src);
if (new_head == cache.rs.size) {
new_head = i;
}
@ -3009,16 +3349,41 @@ static llama_pos llama_cache_seq_cp(
}
static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) {
if (cache.rs.size > 0) {
uint32_t new_head = cache.rs.size;
for (uint32_t i = 0; i < cache.rs.size; ++i) {
llama_rs_cell & rs_cell = cache.rs.cells[i];
if (!rs_cell.seq_nodes.empty()) {
for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) {
if (node_iter->seq_id != seq_id) {
node_iter = cache.rs.remove_seq_node_from_cell(rs_cell, node_iter);
} else {
node_iter = std::next(node_iter);
}
}
if (new_head == cache.rs.size && rs_cell.is_empty()) {
new_head = i;
}
}
}
// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.rs.size && new_head < cache.rs.head) {
cache.rs.head = new_head;
}
}
if (cache.kv.size > 0) {
uint32_t new_head = cache.kv.size;
for (uint32_t i = 0; i < cache.kv.size; ++i) {
llama_kv_cell & kv_cell = cache.kv.cells[i];
if (!kv_cell.has_seq_id(seq_id)) {
if (kv_cell.pos >= 0) cache.kv.used--;
if (kv_cell.pos >= 0) { cache.kv.used--; }
kv_cell.pos = -1;
kv_cell.seq_id.clear();
if (new_head == cache.kv.size) new_head = i;
if (new_head == cache.kv.size) { new_head = i; }
} else {
kv_cell.seq_id.clear();
kv_cell.seq_id.insert(seq_id);
@ -3052,13 +3417,12 @@ static llama_pos llama_cache_seq_add(
while (cell_id >= 0) {
GGML_ASSERT((uint32_t) cell_id < cache.rs.size);
llama_rs_cell & rs_cell = cache.rs.cells[cell_id];
int32_t i = cell_id;
cell_id = rs_cell.prev;
if (rs_cell.pos >= p0 && rs_cell.pos < p1) {
rs_cell.pos += delta;
if (rs_cell.pos < 0) {
// NOTE: this affects the other sequences which share the cell
cache.rs.clear_cell(i);
cache.rs.clear_cell(rs_cell);
// TODO: update cache.rs.head
}
}
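Note: position shifts now walk a sequence's whole history tail-first through the prev links. A toy walk over a small prev array:

// sketch: walking a sequence from its tail back through prev links
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> prev = {-1, 3, -1, 0, 1};   // chain: 0 <- 3 <- 1 <- 4
    int cell_id = 4;                             // seq_tails[seq_id].tail
    while (cell_id >= 0) {
        printf("%d ", cell_id);                  // visits: 4 1 3 0
        cell_id = prev[cell_id];
    }
    printf("\n");
    return 0;
}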
@ -6787,7 +7151,7 @@ struct llm_build_context {
}
struct ggml_tensor * build_inp_s_seq() {
lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_rs, n_tokens);
lctx.inp_s_seq = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(lctx.inp_s_seq, "inp_s_seq", -1);
ggml_set_input(lctx.inp_s_seq);
return lctx.inp_s_seq;
@ -10482,26 +10846,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
int32_t * data = (int32_t *) lctx.inp_s_seq->data;
for (int j = 0; j < n_tokens; ++j) {
const int32_t n_seq = batch.n_seq_id[j];
GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
for (int i = 0; i < n_rs; ++i) {
if (i < n_seq) {
llama_seq_id seq_id = batch.seq_id[j][i];
for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size());
const auto & seq = rs_self.seq_tails[seq_id];
// all sequences of this batch should already be initialized
GGML_ASSERT(seq.tail >= 0);
// ensure the relative cell id will be positive but not too big
GGML_ASSERT((uint32_t) seq.tail >= rs_self.head);
GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n);
data[j*n_rs + i] = seq.tail - rs_self.head;
} else {
data[j*n_rs + i] = -1;
}
}
data[i] = seq.tail - rs_self.head;
}
}
}
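Note: each inp_s_seq entry is now the token's tail cell made relative to the head..head+n window that the graph sees. A worked example with made-up cell ids:

// sketch: relative slot ids written into inp_s_seq
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t head = 2, n = 3;          // cells 2, 3, 4 are visible to the graph
    const int32_t  tails[4] = {4, 2, 4, 3};  // seq.tail for each token's seq_id
    for (int i = 0; i < 4; ++i) {
        assert((uint32_t) tails[i] >= head && (uint32_t) tails[i] < head + n);
        printf("data[%d] = %d\n", i, tails[i] - (int32_t) head);  // 2, 0, 2, 1
    }
    return 0;
}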
@ -14874,7 +15227,7 @@ struct llama_context * llama_new_context_with_model(
memory_size_s += ggml_nbytes(s);
}
LLAMA_LOG_INFO("%s: SSM state size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
LLAMA_LOG_INFO("%s: SSM state size = %8.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f),
ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f));
@ -14891,7 +15244,7 @@ struct llama_context * llama_new_context_with_model(
memory_size_v += ggml_nbytes(v);
}
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
LLAMA_LOG_INFO("%s: KV cache size = %8.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
@ -15458,7 +15811,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
const size_t s_kv = ctx->cache.kv.total_size();
const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
const size_t s_kv_cells = ctx->cache.kv.size * s_kv_cell;
// TODO: rs cache cells
// FIXME: rs cache cells
const size_t s_total = (
+ s_rng_size
@ -15606,14 +15959,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
}
}
// FIXME: copy rs cache
// copy kv cache
{
const auto & kv_self = ctx->kv_self;
const auto & kv_self = ctx->cache.kv;
const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
// NOTE: kv_size and kv_buf_size are mostly used for sanity checks
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
@ -15637,17 +15991,6 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
data_ctx->write(tmp_buf.data(), tmp_buf.size());
if (kv_self.recurrent) {
// v is contiguous for recurrent models
// TODO: use other tensors for state models than k and v
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
tmp_buf.resize(v_size);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
data_ctx->write(tmp_buf.data(), tmp_buf.size());
continue;
}
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
@ -15753,7 +16096,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
}
}
// FIXME: set rs cache too
// FIXME: set rs cache
// set kv cache
{
const auto & kv_self = ctx->cache.kv;