mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-05 18:44:51 +01:00
llama : state checkpoints for recurrent models
This commit is contained in:
parent
8db1e4d45f
commit
0028010d01
74
ggml.c
74
ggml.c
@ -6335,19 +6335,18 @@ struct ggml_tensor * ggml_ssm_conv(
|
||||
GGML_ASSERT(ggml_is_3d(s));
|
||||
GGML_ASSERT(ggml_is_matrix(x));
|
||||
GGML_ASSERT(ggml_is_matrix(c));
|
||||
GGML_ASSERT(ggml_is_matrix(sq));
|
||||
GGML_ASSERT(ggml_is_vector(sq));
|
||||
GGML_ASSERT(sq->type == GGML_TYPE_I32);
|
||||
|
||||
const int64_t d_conv = c->ne[0];
|
||||
const int64_t d_inner = c->ne[1];
|
||||
const int64_t n_tokens = x->ne[1];
|
||||
const int64_t n_kv = s->ne[2];
|
||||
const int64_t n_rs = s->ne[2];
|
||||
|
||||
GGML_ASSERT( s->ne[0] == d_conv - 1);
|
||||
GGML_ASSERT( s->ne[1] == d_inner);
|
||||
GGML_ASSERT( x->ne[0] == d_inner);
|
||||
GGML_ASSERT(sq->ne[0] == n_kv);
|
||||
GGML_ASSERT(sq->ne[1] == n_tokens);
|
||||
GGML_ASSERT(sq->ne[0] == n_tokens);
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
@ -6356,8 +6355,8 @@ struct ggml_tensor * ggml_ssm_conv(
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
|
||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
|
||||
// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs}
|
||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs));
|
||||
|
||||
result->op = GGML_OP_SSM_CONV;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
@ -6410,7 +6409,7 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
|
||||
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs}
|
||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
|
||||
|
||||
result->op = GGML_OP_SSM_SCAN;
|
||||
@ -15087,9 +15086,9 @@ static void ggml_compute_forward_ssm_conv_f32(
|
||||
const int nc = src2->ne[0]; // d_conv
|
||||
const int nr = src0->ne[1]; // d_inner
|
||||
const int n_t = src1->ne[1]; // n_tokens
|
||||
const int n_kv = src0->ne[2]; // max number of sequences in the batch
|
||||
const int n_rs = src0->ne[2]; // max number of sequences in the batch
|
||||
|
||||
GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
|
||||
GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst));
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
||||
@ -15106,10 +15105,12 @@ static void ggml_compute_forward_ssm_conv_f32(
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
const int ir = ir1 - ir0;
|
||||
|
||||
if (n_kv > 1) {
|
||||
const int32_t * sq = src3->data; // {n_tokens}
|
||||
|
||||
if (n_rs > 1) {
|
||||
// multiple sequences means it's hard to know when it's the first time a state is read,
|
||||
// so copy them all over to the destination, just to be sure.
|
||||
for (int i3 = 0; i3 < n_kv; ++i3) {
|
||||
for (int i3 = 0; i3 < n_rs; ++i3) {
|
||||
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
|
||||
// can't use memcpy because of d_conv vs d_conv - 1
|
||||
@ -15123,19 +15124,19 @@ static void ggml_compute_forward_ssm_conv_f32(
|
||||
}
|
||||
|
||||
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||
int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
|
||||
int32_t sq_i = sq[i2];
|
||||
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
|
||||
float * s0; // {d_conv - 1, d_inner, n_kv}
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs}
|
||||
float * s0; // {d_conv - 1, d_inner, n_rs}
|
||||
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
|
||||
int ne0s0;
|
||||
|
||||
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
|
||||
GGML_ASSERT(0 <= sq_i && sq_i < n_rs);
|
||||
|
||||
// avoid needing to copy the state for the first token
|
||||
if (i2 == 0) {
|
||||
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
|
||||
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs}
|
||||
ne0s0 = src0->ne[0];
|
||||
} else {
|
||||
// the source is the last (d_conv - 1) columns of the destination
|
||||
@ -15153,18 +15154,6 @@ static void ggml_compute_forward_ssm_conv_f32(
|
||||
s[(nc - 1) + i1*nc] = x0[i1];
|
||||
}
|
||||
|
||||
// handle copies when there are multiple output states
|
||||
for (int i3 = 1; i3 < n_kv; ++i3) {
|
||||
int32_t seq = sq[i3];
|
||||
if (0 <= seq && seq < n_kv) {
|
||||
float * s1 = s + (seq - sq[0])*nc*nr;
|
||||
memcpy(s1, s, nc*ir*sizeof(float));
|
||||
} else {
|
||||
// stop at negative or too big seq_ids
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// it seems a little faster when this is separate from the state shift
|
||||
for (int i1 = 0; i1 < ir; ++i1) {
|
||||
// rowwise dot product
|
||||
@ -15216,7 +15205,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
const int64_t nc = src0->ne[0]; // d_state
|
||||
const int64_t nr = src0->ne[1]; // d_inner
|
||||
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
|
||||
const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
|
||||
const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch
|
||||
|
||||
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
@ -15225,6 +15214,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
|
||||
// required for the dot product between s and C, and when copying the states
|
||||
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||
// required for per-sequence offsets for states
|
||||
@ -15240,10 +15230,12 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
const int ir = ir1 - ir0;
|
||||
|
||||
if (n_kv > 1) {
|
||||
const int32_t * sq = src6->data; // {n_tokens}
|
||||
|
||||
if (n_rs > 1) {
|
||||
// it's hard to know if the source states have already been copied
|
||||
// when there are multiple, so copy them already.
|
||||
for (int i3 = 0; i3 < n_kv; ++i3) {
|
||||
for (int i3 = 0; i3 < n_rs; ++i3) {
|
||||
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
|
||||
memcpy(s, s0, nc*ir*sizeof(float));
|
||||
@ -15251,9 +15243,9 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
}
|
||||
|
||||
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||
int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
|
||||
int32_t sq_i = sq[i2];
|
||||
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
|
||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs}
|
||||
float * s0;
|
||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
|
||||
@ -15261,11 +15253,11 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
|
||||
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
|
||||
|
||||
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
|
||||
GGML_ASSERT(0 <= sq_i && sq_i < n_rs);
|
||||
|
||||
// avoid needing to copy the state for the first token
|
||||
if (i2 == 0) {
|
||||
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
|
||||
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs}
|
||||
} else {
|
||||
// otherwise the source is the same as the destination
|
||||
s0 = s;
|
||||
@ -15288,18 +15280,6 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
}
|
||||
y[i1] = sumf;
|
||||
}
|
||||
|
||||
// handle copies when there are multiple output states
|
||||
for (int i3 = 1; i3 < n_kv; ++i3) {
|
||||
int32_t seq = sq[i3];
|
||||
if (0 <= seq && seq < n_kv) {
|
||||
float * s1 = s + (seq - sq[0])*nc*nr;
|
||||
memcpy(s1, s, nc*ir*sizeof(float));
|
||||
} else {
|
||||
// stop at negative or too big seq_ids
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
647
llama.cpp
647
llama.cpp
@ -2016,11 +2016,13 @@ struct llama_rs_seq_meta {
|
||||
// number of cells for which this seq_id is the first
|
||||
// (useful to know if cells in this sequence should be pruned)
|
||||
int32_t n_cells = 0;
|
||||
// whether the tail is a cell part of multiple sequences
|
||||
bool shared = false;
|
||||
// changing the tail cell of a sequence can only be done at batch boundary,
|
||||
// this guards against changing the cell when it shouldn't be;
|
||||
// should be cleared when done finding a slot
|
||||
bool in_ubatch = false;
|
||||
};
|
||||
|
||||
// ring-buffer of cached recurrent state data
|
||||
// ring-buffered tree of cached recurrent state data
|
||||
struct llama_rs_cache {
|
||||
bool do_copy = false;
|
||||
|
||||
@ -2032,8 +2034,10 @@ struct llama_rs_cache {
|
||||
uint32_t n = 0; // range of states used for the last slot
|
||||
|
||||
// useful to know the minimum reserved cell count per seq_id
|
||||
// only counts sequences with n_cells > 0
|
||||
// only counts sequences with n_cells > 0 AND which have a non-shared tail
|
||||
uint32_t n_seqs = 0;
|
||||
// cells part of multiple sequences AND which have at least one tail
|
||||
uint32_t n_shared_tail_cells = 0;
|
||||
|
||||
// with state models, a cell can hold the state for more than one past token
|
||||
// TODO: it's probably not possible to always use contiguous cells
|
||||
@ -2047,10 +2051,204 @@ struct llama_rs_cache {
|
||||
std::vector<struct ggml_tensor *> r_l; // rolling/shift states
|
||||
std::vector<struct ggml_tensor *> s_l; // ssm (recurrent) states
|
||||
|
||||
// TODO: maybe use a simpler data structure than a tree
|
||||
|
||||
// Inefficient, but thorough verification and rebuilding of the rs cache
|
||||
// from only the cells list with `pos` and seq_ids.
|
||||
// Should not be called in a hot loop except when desperate and/or debugging.
|
||||
bool rebuild(bool debug) {
|
||||
bool was_valid = true;
|
||||
// the source of truth is the cells list
|
||||
// buffer sizes
|
||||
if (size != cells.size()) {
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n",
|
||||
__func__, cells.size(), size);
|
||||
}
|
||||
cells.resize(size);
|
||||
was_valid = false;
|
||||
}
|
||||
if (size != seq_tails.size()) {
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n",
|
||||
__func__, seq_tails.size(), size);
|
||||
}
|
||||
seq_tails.resize(size);
|
||||
was_valid = false;
|
||||
}
|
||||
// cells consistency
|
||||
uint32_t used_verif = 0;
|
||||
for (uint32_t cell_id = 0; cell_id < size; ++cell_id) {
|
||||
llama_rs_cell & cell = cells[cell_id];
|
||||
if (cell.seq_nodes.empty()) {
|
||||
if (cell.pos >= 0) {
|
||||
cell.pos = -1;
|
||||
was_valid = false;
|
||||
}
|
||||
}
|
||||
if (cell.pos < 0) {
|
||||
if (cell.pos != -1) {
|
||||
cell.pos = -1;
|
||||
was_valid = false;
|
||||
}
|
||||
if (!cell.seq_nodes.empty()) {
|
||||
cell.seq_nodes.clear();
|
||||
was_valid = false;
|
||||
}
|
||||
cell.src = -1;
|
||||
if (cell.prev != -1) {
|
||||
cell.prev = -1;
|
||||
was_valid = false;
|
||||
}
|
||||
} else if (!debug) {
|
||||
// Assuming the cache should be actually rebuilt when not debugging
|
||||
cell.src = cell_id;
|
||||
}
|
||||
if (!cell.seq_nodes.empty()) {
|
||||
used_verif += 1;
|
||||
}
|
||||
}
|
||||
if (used != used_verif) {
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: invalid used cell count (%u instead of %u)\n",
|
||||
__func__, used, used_verif);
|
||||
}
|
||||
used = used_verif;
|
||||
was_valid = false;
|
||||
}
|
||||
// tail verification
|
||||
std::vector<std::pair<llama_pos, uint32_t>> seq_cells;
|
||||
for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) {
|
||||
auto & seq = seq_tails[seq_id];
|
||||
seq_cells.clear();
|
||||
for (uint32_t cell_id = 0; cell_id < size; ++cell_id) {
|
||||
llama_rs_cell & cell = cells[cell_id];
|
||||
if (cell.has_seq_id(seq_id)) {
|
||||
seq_cells.push_back({cell.pos, cell_id});
|
||||
}
|
||||
}
|
||||
// sort by pos and then by cell_id
|
||||
std::sort(seq_cells.begin(), seq_cells.end());
|
||||
int32_t tail = seq_cells.empty() ? -1 : seq_cells[seq_cells.size() - 1].second;
|
||||
if (tail != seq.tail) {
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: invalid tail for seq_id %d (%d instead of %d)\n",
|
||||
__func__, seq_id, seq.tail, tail);
|
||||
}
|
||||
seq.tail = tail;
|
||||
was_valid = false;
|
||||
}
|
||||
int32_t prev = -1;
|
||||
for (size_t i = 0; i < seq_cells.size(); ++i) {
|
||||
uint32_t cell_id = seq_cells[i].second;
|
||||
llama_rs_cell & cell = cells[cell_id];
|
||||
if (cell.prev != prev) {
|
||||
// TODO: relax the error when multiple cells have the same pos
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: invalid prev cell for cells[%u] (%d instead of %d)\n",
|
||||
__func__, cell_id, cell.prev, prev);
|
||||
}
|
||||
cell.prev = prev;
|
||||
was_valid = false;
|
||||
}
|
||||
prev = cell_id;
|
||||
}
|
||||
int32_t n_cells = 0;
|
||||
int32_t next = -1;
|
||||
for (size_t i = seq_cells.size(); i-- > 0;) {
|
||||
uint32_t cell_id = seq_cells[i].second;
|
||||
llama_rs_cell & cell = cells[cell_id];
|
||||
// assuming it's always found, because how else would it end up in the list of cells for this seq_id?
|
||||
auto seq_node = std::find(cell.seq_nodes.begin(), cell.seq_nodes.end(), seq_id);
|
||||
if (seq_node == cell.seq_nodes.begin()) {
|
||||
n_cells += 1;
|
||||
}
|
||||
if (seq_node->next_cell != next) {
|
||||
// TODO: relax the error when multiple cells have the same pos
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: invalid next cell for cells[%u] (%d instead of %d)\n",
|
||||
__func__, cell_id, seq_node->next_cell, next);
|
||||
}
|
||||
seq_node->next_cell = next;
|
||||
was_valid = false;
|
||||
}
|
||||
next = cell_id;
|
||||
}
|
||||
if (seq.n_cells != n_cells) {
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: invalid n_cells for seq_id %d (%d instead of %d)\n",
|
||||
__func__, seq_id, seq.n_cells, n_cells);
|
||||
}
|
||||
seq.n_cells = n_cells;
|
||||
}
|
||||
// in_batch should only be true when in the process of finding a slot
|
||||
if (seq.in_ubatch != false) {
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: in_ubatch was true while it should have been false for seq_id %d\n",
|
||||
__func__, seq_id);
|
||||
}
|
||||
seq.in_ubatch = false;
|
||||
was_valid = false;
|
||||
}
|
||||
}
|
||||
// tail_rc
|
||||
for (uint32_t cell_id = 0; cell_id < size; ++cell_id) {
|
||||
llama_rs_cell & cell = cells[cell_id];
|
||||
uint32_t tail_rc = 0;
|
||||
for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) {
|
||||
auto & seq = seq_tails[seq_id];
|
||||
if (seq.tail >= 0 && (uint32_t) seq.tail == cell_id) {
|
||||
tail_rc += 1;
|
||||
}
|
||||
}
|
||||
if (cell.tail_rc != tail_rc) {
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: invalid tail_rc for cells[%u] (%u instead of %u)\n",
|
||||
__func__, cell_id, cell.tail_rc, tail_rc);
|
||||
}
|
||||
cell.tail_rc = tail_rc;
|
||||
was_valid = false;
|
||||
}
|
||||
}
|
||||
// n_seqs
|
||||
uint32_t n_seqs_verif = 0;
|
||||
uint32_t n_shared_tail_cells_verif = 0;
|
||||
for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) {
|
||||
auto & seq = seq_tails[seq_id];
|
||||
if (seq.tail >= 0) {
|
||||
llama_rs_cell & tail_cell = cells[seq.tail];
|
||||
// NOTE: could also have checked if n_cells > 0
|
||||
if (!tail_cell.seq_nodes.empty() && tail_cell.seq_nodes[0].seq_id == seq_id) {
|
||||
if (tail_cell.seq_nodes.size() > 1) {
|
||||
n_shared_tail_cells_verif += 1;
|
||||
} else {
|
||||
n_seqs_verif += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (n_seqs != n_seqs_verif) {
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: wrong n_seqs (%u instead of %u)\n",
|
||||
__func__, n_seqs, n_seqs_verif);
|
||||
}
|
||||
n_seqs = n_seqs_verif;
|
||||
was_valid = false;
|
||||
}
|
||||
if (n_shared_tail_cells != n_shared_tail_cells_verif) {
|
||||
if (debug) {
|
||||
LLAMA_LOG_ERROR("%s: wrong n_shared_tail_cells (%u instead of %u)\n",
|
||||
__func__, n_shared_tail_cells, n_shared_tail_cells_verif);
|
||||
}
|
||||
n_shared_tail_cells = n_shared_tail_cells_verif;
|
||||
was_valid = false;
|
||||
}
|
||||
return was_valid;
|
||||
}
|
||||
|
||||
// returns whether or not a cell was freed
|
||||
bool clear_cell(uint32_t i) {
|
||||
if (i < size) {
|
||||
llama_rs_cell & rs_cell = cells[i];
|
||||
void clear_cell(llama_rs_cell & rs_cell) {
|
||||
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
|
||||
if (!rs_cell.is_empty()) {
|
||||
// update sequence tree links
|
||||
bool first = true;
|
||||
@ -2059,27 +2257,34 @@ struct llama_rs_cache {
|
||||
// NOTE: if all next cells are the same cell, this should still work
|
||||
cells[node.next_cell].prev = rs_cell.prev;
|
||||
}
|
||||
// next_cell of the nodes of the previous cell
|
||||
if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) {
|
||||
llama_rs_cell & prev_cell = cells[rs_cell.prev];
|
||||
auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node);
|
||||
// assuming the previous node is always found
|
||||
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
|
||||
prev_node->next_cell = node.next_cell;
|
||||
if (node.is_tail()) {
|
||||
prev_cell.tail_rc += 1;
|
||||
}
|
||||
}
|
||||
if ((uint32_t) node.seq_id < seq_tails.size()) {
|
||||
auto & seq = seq_tails[node.seq_id];
|
||||
// update tail
|
||||
if (node.is_tail()) {
|
||||
seq.tail = rs_cell.prev;
|
||||
if (seq.tail >= 0 && (uint32_t) seq.tail < size) {
|
||||
llama_rs_cell & new_tail = cells[seq.tail];
|
||||
new_tail.insert_node(node.seq_id); // ensures next_cell == -1
|
||||
new_tail.tail_rc += 1;
|
||||
seq.shared = new_tail.seq_nodes.size() > 1;
|
||||
} else {
|
||||
seq.shared = false;
|
||||
}
|
||||
}
|
||||
// cell counts
|
||||
if (first) {
|
||||
seq.n_cells -= 1;
|
||||
if (seq.n_cells == 0) {
|
||||
GGML_ASSERT(seq.tail < 0);
|
||||
if (rs_cell.tail_rc > 0 && seq.tail < 0) {
|
||||
// last tail cell
|
||||
if (rs_cell.seq_nodes.size() > 1) {
|
||||
n_shared_tail_cells -= 1;
|
||||
} else {
|
||||
n_seqs -= 1;
|
||||
}
|
||||
}
|
||||
first = false;
|
||||
}
|
||||
} else {
|
||||
@ -2092,39 +2297,42 @@ struct llama_rs_cache {
|
||||
rs_cell.tail_rc = 0;
|
||||
rs_cell.seq_nodes.clear();
|
||||
used -= 1;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: maybe use a simpler data structure than a tree
|
||||
// returns whether or not a cell was freed
|
||||
bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) {
|
||||
if (i_cell < size && (size_t) id < size) {
|
||||
llama_rs_cell & rs_cell = cells[i_cell];
|
||||
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once
|
||||
// returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed.
|
||||
std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) {
|
||||
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
|
||||
// TODO: assert the iterator points inside the correct vector
|
||||
if (node_iter != rs_cell.seq_nodes.end()) {
|
||||
if (rs_cell.seq_nodes.size() == 1) {
|
||||
return clear_cell(i_cell);
|
||||
clear_cell(rs_cell);
|
||||
return rs_cell.seq_nodes.end();
|
||||
}
|
||||
// else update tree
|
||||
llama_rs_seq_node node = *node_iter;
|
||||
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
|
||||
cells[node.next_cell].prev = rs_cell.prev;
|
||||
}
|
||||
if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) {
|
||||
llama_rs_cell & prev_cell = cells[rs_cell.prev];
|
||||
auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node);
|
||||
// assuming the previous node is always found
|
||||
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
|
||||
prev_node->next_cell = node.next_cell;
|
||||
if (node.is_tail()) {
|
||||
prev_cell.tail_rc += 1;
|
||||
}
|
||||
}
|
||||
if ((uint32_t) node.seq_id < seq_tails.size()) {
|
||||
auto & seq = seq_tails[node.seq_id];
|
||||
bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2;
|
||||
if (node.is_tail()) {
|
||||
seq.tail = rs_cell.prev;
|
||||
if (seq.tail >= 0 && (uint32_t) seq.tail < size) {
|
||||
llama_rs_cell & new_tail = cells[seq.tail];
|
||||
new_tail.insert_node(node.seq_id); // ensures next_cell == -1
|
||||
new_tail.tail_rc += 1;
|
||||
seq.shared = cells[seq.tail].seq_nodes.size() > 1;
|
||||
} else {
|
||||
seq.shared = false;
|
||||
if (seq.tail < 0 && rs_cell.tail_rc == 1) {
|
||||
// assuming the previous cell of a shared cell is also shared,
|
||||
// (no need to update the shared tail cells count elsewhere, then)
|
||||
// this was a shared tail cell, but will no longer be a tail cell
|
||||
n_shared_tail_cells -= 1;
|
||||
}
|
||||
GGML_ASSERT(rs_cell.tail_rc > 0);
|
||||
rs_cell.tail_rc -= 1;
|
||||
@ -2132,42 +2340,43 @@ struct llama_rs_cache {
|
||||
if (node_iter == rs_cell.seq_nodes.begin()) {
|
||||
// this seq_id was the first in the list
|
||||
seq.n_cells -= 1;
|
||||
if (seq.n_cells == 0) {
|
||||
n_seqs -= 1;
|
||||
}
|
||||
|
||||
// the next node is the new first one, so update its n_cells
|
||||
// (will never be out-of-bounds because the size is > 1)
|
||||
llama_rs_seq_node next_node = *(std::next(node_iter));
|
||||
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
|
||||
auto & next_seq = seq_tails[next_node.seq_id];
|
||||
next_seq.n_cells += 1;
|
||||
if (next_seq.n_cells == 1) {
|
||||
// only the tail ref count from the other seq_ids are left in tail_rc
|
||||
if (rs_cell.tail_rc > 0) {
|
||||
// will become a non-shared cell
|
||||
if (rs_cell.seq_nodes.size() == 2) {
|
||||
n_seqs += 1;
|
||||
}
|
||||
if (other_no_longer_shared) {
|
||||
next_seq.shared = false;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false && "invalid seq_id");
|
||||
}
|
||||
} else if (other_no_longer_shared) {
|
||||
llama_rs_seq_node first_node = rs_cell.seq_nodes[0];
|
||||
if ((uint32_t) first_node.seq_id < seq_tails.size()) {
|
||||
seq_tails[first_node.seq_id].shared = false;
|
||||
} else {
|
||||
GGML_ASSERT(false && "invalid seq_id");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false && "invalid seq_id");
|
||||
}
|
||||
rs_cell.seq_nodes.erase(node_iter);
|
||||
return rs_cell.seq_nodes.erase(node_iter);
|
||||
}
|
||||
return node_iter;
|
||||
}
|
||||
|
||||
// returns whether or not the seq_id was removed
|
||||
bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) {
|
||||
if (i_cell < size && (size_t) id < size) {
|
||||
llama_rs_cell & rs_cell = cells[i_cell];
|
||||
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once
|
||||
return node_iter != remove_seq_node_from_cell(rs_cell, node_iter);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool insert_seq_tail_to_cell(uint32_t i_cell, const llama_seq_id & id) {
|
||||
bool insert_seq_tail_to_cell_id(uint32_t i_cell, const llama_seq_id & id) {
|
||||
if (i_cell < size && (size_t) id < seq_tails.size()) {
|
||||
llama_rs_cell & rs_cell = cells[i_cell];
|
||||
auto & seq = seq_tails[id];
|
||||
@ -2194,10 +2403,11 @@ struct llama_rs_cache {
|
||||
}
|
||||
prev_cell.tail_rc -= 1;
|
||||
prev_node->next_cell = i_cell;
|
||||
rs_cell.prev = prev;
|
||||
}
|
||||
if (rs_cell.is_empty()) {
|
||||
// only add after potential failures above
|
||||
if (seq.n_cells == 0) {
|
||||
// either the sequence didn't own any cells or had a shared tail cell
|
||||
if (seq.n_cells == 0 || (seq.tail >= 0 && cells[seq.tail].seq_nodes.size() > 1)) {
|
||||
n_seqs += 1;
|
||||
}
|
||||
seq.n_cells += 1;
|
||||
@ -2206,12 +2416,40 @@ struct llama_rs_cache {
|
||||
rs_cell.pos = 0;
|
||||
rs_cell.src = -1;
|
||||
}
|
||||
used += 1;
|
||||
} else if (rs_cell.seq_nodes.size() == 1 && rs_cell.tail_rc == 1) {
|
||||
// don't count shared-cell tails
|
||||
// FIXME: make this saner
|
||||
n_seqs -= 1;
|
||||
n_shared_tail_cells += 1;
|
||||
} else if (rs_cell.tail_rc == 0) {
|
||||
// shared cell without a tail gets a tail;
|
||||
// FIXME: don't prune, in case this is used in llama_cache_seq_cp
|
||||
GGML_ASSERT(false); // make sure we don't get here by accident
|
||||
// prune the other sequences out of this cell
|
||||
// NOTE: have to inline the removal because the state tree is partially invalid
|
||||
bool first = true;
|
||||
for (auto & node : rs_cell.seq_nodes) {
|
||||
GGML_ASSERT(node.seq_id != id);
|
||||
GGML_ASSERT(node.next_cell >= 0);
|
||||
// easy removal, none of the nodes are tails
|
||||
llama_rs_cell & next_cell = cells[node.next_cell];
|
||||
next_cell.prev = rs_cell.prev;
|
||||
if (first) {
|
||||
auto & first_seq = seq_tails[node.seq_id];
|
||||
first_seq.n_cells -= 1;
|
||||
first = false;
|
||||
}
|
||||
}
|
||||
rs_cell.seq_nodes.clear();
|
||||
} else if (rs_cell.seq_nodes.size() != rs_cell.tail_rc) {
|
||||
// this is correct as long as this isn't called when trying to find a slot
|
||||
// TODO: find a way to assert this
|
||||
}
|
||||
// the target cell was not already a tail of this seq_id
|
||||
rs_cell.insert_node(id); // next_cell == -1 by default
|
||||
rs_cell.tail_rc += 1;
|
||||
seq.tail = i_cell;
|
||||
seq.shared = rs_cell.seq_nodes.size() > 1;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@ -2219,33 +2457,12 @@ struct llama_rs_cache {
|
||||
|
||||
// each seq_id should have access to at least this many cells
|
||||
// (to use when pruning (to avoid over-pruning))
|
||||
// (but this over-prunes when the system prompt doesn't take lots of cells)
|
||||
// Hmm. The system prompt does not need checkpoints...
|
||||
size_t min_cells_per_seq() const {
|
||||
return size / (n_seqs > 0 ? n_seqs : 1);
|
||||
size_t min_cells_per_seq(const llama_rs_seq_meta & new_seq) const {
|
||||
uint32_t seqs = n_seqs;
|
||||
if (new_seq.tail < 0 || new_seq.n_cells == 0) {
|
||||
seqs += 1;
|
||||
}
|
||||
|
||||
// each seq_id can have at most this many cells
|
||||
// (ignoring seqs which behave as a shared prompt)
|
||||
// TODO: avoid recalculating system seq_ids
|
||||
// (to use when pruning (to avoid over-pruning))
|
||||
// NOTE: this also limits the shared prompt to at most half the cells
|
||||
// (but the shared prompt technically needs only one cell...)
|
||||
// (IDEA: keep only one cell when `llama_kv_cache_seq_cp` is called on a sequence)
|
||||
size_t max_cells_per_seq() const {
|
||||
int32_t n_system_seqs = 0;
|
||||
int32_t n_system_cells = 0;
|
||||
for (size_t i = 0; i < seq_tails.size(); ++i) {
|
||||
const auto & seq = seq_tails[i];
|
||||
if (seq.tail >= 0 && (size_t) seq.tail < size) {
|
||||
if (seq.shared && seq.n_cells > 0) {
|
||||
n_system_seqs += 1;
|
||||
n_system_cells += seq.n_cells;
|
||||
}
|
||||
}
|
||||
}
|
||||
int32_t n_other_seqs = n_seqs - n_system_seqs;
|
||||
return (size - n_system_cells) / (n_other_seqs > 0 ? n_other_seqs : 1);
|
||||
return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1);
|
||||
}
|
||||
|
||||
size_t total_size() const {
|
||||
@ -2528,7 +2745,7 @@ struct llama_context {
|
||||
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_s_copy; // I32 [n_rs]
|
||||
struct ggml_tensor * inp_s_mask; // F32 [1, n_rs]
|
||||
struct ggml_tensor * inp_s_seq; // I32 [n_rs, n_batch]
|
||||
struct ggml_tensor * inp_s_seq; // I32 [n_batch]
|
||||
|
||||
// control vectors
|
||||
struct llama_control_vector cvec;
|
||||
@ -2657,7 +2874,7 @@ static bool llama_cache_init(
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
||||
LLAMA_LOG_INFO("%s: %10s ctx buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
||||
cache.bufs.push_back(buf);
|
||||
}
|
||||
|
||||
@ -2678,54 +2895,170 @@ static bool llama_kv_cache_find_slot(
|
||||
if (rs_size > 0) {
|
||||
// For recurrent state architectures (like Mamba),
|
||||
// each cache cell can store the state for a whole sequence.
|
||||
// TODO: real ring-buffer of states
|
||||
// TODO: state chekpoints (multiple cells per sequence)
|
||||
// TODO: find a way to always make the rs slot contiguous
|
||||
|
||||
// Okay, need to find a slot. Everything should fit assuming the biggest seq_id < rs_size
|
||||
|
||||
|
||||
llama_seq_id min = cache.rs.size - 1;
|
||||
llama_seq_id max = 0;
|
||||
llama_seq_id min_seq = cache.rs.size - 1;
|
||||
llama_seq_id max_seq = 0;
|
||||
uint32_t min_cell = cache.rs.size - 1;
|
||||
uint32_t max_cell = 0;
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
|
||||
int32_t target_cell = -1; // ensure all the sequences of a token get the same cell
|
||||
int32_t n_seq_ids = batch.n_seq_id[i];
|
||||
for (int32_t j = 0; j < n_seq_ids; ++j) {
|
||||
llama_seq_id seq_id = batch.seq_id[i][j];
|
||||
// make sure it's a valid seq_id
|
||||
bool need_new_cell = false;
|
||||
// Everything should fit assuming the biggest seq_id < rs_size
|
||||
if ((uint32_t) seq_id < rs_size) {
|
||||
if (seq_id > max) {
|
||||
max = seq_id;
|
||||
llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id];
|
||||
if (seq_id > max_seq) { max_seq = seq_id; }
|
||||
if (seq_id < min_seq) { min_seq = seq_id; }
|
||||
|
||||
if (!seq.in_ubatch && target_cell >= 0) {
|
||||
// never saw this seq_id before,
|
||||
// but there's already a cell reserved for this token, use it
|
||||
cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id);
|
||||
} else if (seq.tail < 0) {
|
||||
need_new_cell = true;
|
||||
} else {
|
||||
llama_rs_cell & tail = cache.rs.cells[seq.tail];
|
||||
if (seq.in_ubatch) {
|
||||
// this seq_id was already seen before in the batch
|
||||
// assuming the tail cell already "has" this seq_id
|
||||
tail.pos += 1;
|
||||
target_cell = seq.tail;
|
||||
} else {
|
||||
// first time this sequence is seen,
|
||||
// there's no reserved cell yet;
|
||||
// if it's not the first sequence of the token, how could it even get here?
|
||||
GGML_ASSERT(j == 0);
|
||||
|
||||
bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids;
|
||||
if (has_same_seqs) {
|
||||
// the tail cell of a seq_id is assumed to already be part of the seq_id,
|
||||
// hence the skip of the first seq_id
|
||||
for (int32_t k = 1; k < n_seq_ids; ++k) {
|
||||
if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) {
|
||||
has_same_seqs = false;
|
||||
}
|
||||
if (seq_id < min) {
|
||||
min = seq_id;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make the checkpoint interval configurable
|
||||
if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) {
|
||||
// a checkpoint should be saved
|
||||
need_new_cell = true;
|
||||
} else {
|
||||
// re-use last tail
|
||||
tail.pos += 1;
|
||||
target_cell = seq.tail;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (need_new_cell && target_cell < 0) {
|
||||
const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq);
|
||||
|
||||
uint32_t cell_id = cache.rs.size;
|
||||
bool looped_once = false;
|
||||
|
||||
while (true) {
|
||||
if (cache.rs.head >= cache.rs.size) {
|
||||
cache.rs.head = 0;
|
||||
if (looped_once) {
|
||||
// avoid infinite loop
|
||||
// NOTE: this should not happen, but gracefully fail anyway
|
||||
LLAMA_LOG_ERROR("%s: recurrent state cache seems full, but should not. This is a bug.\n", __func__);
|
||||
return false;
|
||||
}
|
||||
looped_once = true;
|
||||
}
|
||||
cell_id = cache.rs.head;
|
||||
llama_rs_cell & candidate = cache.rs.cells[cell_id];
|
||||
if (candidate.is_empty()) { break; }
|
||||
if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) {
|
||||
if (candidate.seq_nodes.size() > 1) {
|
||||
// prune out the other seq_ids, because they diverge
|
||||
// TODO(maybe): hande this in insert_seq_tail_to_cell_id
|
||||
// (hopefully doesn't happen too often)
|
||||
for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) {
|
||||
if (node_iter->seq_id == seq_id) {
|
||||
node_iter = std::next(node_iter);
|
||||
} else {
|
||||
node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter);
|
||||
}
|
||||
}
|
||||
}
|
||||
// re-use the tail cell to avoid not finding anything
|
||||
candidate.pos += 1;
|
||||
break;
|
||||
}
|
||||
if (candidate.tail_rc > 0) {
|
||||
// skip tails of other sequences
|
||||
cache.rs.head += 1;
|
||||
continue;
|
||||
}
|
||||
if (candidate.seq_nodes.size() > 1) {
|
||||
// shared prompts are not usually backtracked, so they can be pruned
|
||||
cache.rs.clear_cell(candidate);
|
||||
break;
|
||||
}
|
||||
|
||||
// prune too-long sequences
|
||||
llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id;
|
||||
if (seq_id_to_prune == seq_id) {
|
||||
// TODO: selectively skip some cells to keep older states
|
||||
cache.rs.clear_cell(candidate);
|
||||
break;
|
||||
}
|
||||
GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size());
|
||||
auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune];
|
||||
if (seq_to_prune.n_cells > min_cells_per_seq) {
|
||||
cache.rs.clear_cell(candidate);
|
||||
break;
|
||||
}
|
||||
cache.rs.head += 1;
|
||||
}
|
||||
if (cell_id < cache.rs.size) {
|
||||
cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id);
|
||||
target_cell = cell_id;
|
||||
}
|
||||
}
|
||||
|
||||
if (seq.tail >= 0) {
|
||||
if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; }
|
||||
if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; }
|
||||
seq.in_ubatch = true;
|
||||
}
|
||||
|
||||
// Assuming the tokens are in-order
|
||||
if (batch.pos[i] != cache.rs.cells[seq_id].pos + 1) {
|
||||
if (batch.pos[i] != cache.rs.cells[seq.tail].pos) {
|
||||
// What should happen when the pos backtracks or skips a value?
|
||||
// Clearing the state mid-batch would require special-casing which isn't done.
|
||||
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
|
||||
__func__, batch.pos[i], cache.rs.cells[seq_id].pos, seq_id);
|
||||
__func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id);
|
||||
}
|
||||
if (cache.rs.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
|
||||
cache.rs.used += 1;
|
||||
}
|
||||
cache.rs.cells[seq_id].pos = batch.pos[i];
|
||||
cache.rs.cells[seq_id].seq_nodes.insert(seq_id);
|
||||
} else {
|
||||
// too big seq_id
|
||||
// TODO: would it be possible to resize the KV cache size instead?
|
||||
// TODO: would it be possible to resize the rs cache size instead?
|
||||
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
cache.rs.head = target_cell + 1;
|
||||
}
|
||||
|
||||
for (llama_seq_id i = min_seq; i <= max_seq; ++i) {
|
||||
// make sure it's cleared for next time
|
||||
cache.rs.seq_tails[i].in_ubatch = false;
|
||||
}
|
||||
|
||||
// allow getting the range of used cells, from head to head + n
|
||||
cache.rs.head = min;
|
||||
cache.rs.n = max - min + 1;
|
||||
cache.rs.head = min_cell;
|
||||
cache.rs.n = max_cell - min_cell + 1;
|
||||
|
||||
// sanity check
|
||||
if (max < min) {
|
||||
if (max_seq < min_seq || max_cell < min_cell) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -2799,6 +3132,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// find how many recurrent state cells are currently in use
|
||||
static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) {
|
||||
for (uint32_t i = cache.size; i > 0; --i) {
|
||||
const llama_rs_cell & cell = cache.cells[i - 1];
|
||||
@ -2829,12 +3163,15 @@ static void llama_cache_clear(struct llama_cache & cache) {
|
||||
llama_rs_cell & rs_cell = cache.rs.cells[i];
|
||||
rs_cell.pos = -1;
|
||||
rs_cell.src = -1;
|
||||
rs_cell.prev = -1;
|
||||
rs_cell.tail_rc = 0;
|
||||
rs_cell.seq_nodes.clear();
|
||||
}
|
||||
cache.rs.do_copy = false;
|
||||
cache.rs.head = 0;
|
||||
cache.rs.used = 0;
|
||||
cache.rs.n_seqs = 0;
|
||||
cache.rs.n_shared_tail_cells = 0;
|
||||
cache.rs.seq_tails.clear();
|
||||
cache.rs.seq_tails.resize(cache.rs.size);
|
||||
}
|
||||
@ -2846,8 +3183,8 @@ static llama_pos llama_cache_seq_rm(
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
if (p0 < 0) { p0 = 0; }
|
||||
if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }
|
||||
|
||||
llama_pos n_past = p0;
|
||||
|
||||
@ -2863,7 +3200,9 @@ static llama_pos llama_cache_seq_rm(
|
||||
|
||||
for (uint32_t i = 0; i < cache.rs.size; ++i) {
|
||||
llama_rs_cell & rs_cell = cache.rs.cells[i];
|
||||
if (seq_id < 0 || rs_cell.has_seq_id(seq_id)) {
|
||||
auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id);
|
||||
|
||||
if (seq_id < 0 || seq_node != rs_cell.seq_nodes.end()) {
|
||||
if (rs_cell.pos < p0) {
|
||||
// move forward the new p0 further
|
||||
if (rs_cell.pos >= new_p0) {
|
||||
@ -2879,9 +3218,9 @@ static llama_pos llama_cache_seq_rm(
|
||||
}
|
||||
} else { // (rs_cell.pos >= p0 && rs_cell.pos < p1)
|
||||
if (seq_id < 0) {
|
||||
cache.rs.clear_cell(i);
|
||||
cache.rs.clear_cell(rs_cell);
|
||||
} else { // (rs_cell.has_seq_id(seq_id))
|
||||
cache.rs.remove_seq_from_cell(i, seq_id);
|
||||
cache.rs.remove_seq_node_from_cell(rs_cell, seq_node);
|
||||
}
|
||||
if (rs_cell.is_empty() && new_head == cache.rs.size) {
|
||||
new_head = i;
|
||||
@ -2943,11 +3282,12 @@ static llama_pos llama_cache_seq_cp(
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
if (p0 < 0) { p0 = 0; }
|
||||
if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }
|
||||
|
||||
// TODO: in practice this seems to be only used on whole sequences;
|
||||
// should partial sequence copy be removed?
|
||||
// should partial sequence copy support be removed?
|
||||
// TODO: What if the destination sequence is not empty?
|
||||
|
||||
llama_pos n_past = 0;
|
||||
|
||||
@ -2973,11 +3313,11 @@ static llama_pos llama_cache_seq_cp(
|
||||
if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) {
|
||||
if (i == (uint32_t) src_tail) {
|
||||
// need to be inserted in order, but there's only one
|
||||
cache.rs.insert_seq_tail_to_cell(i, seq_id_dst);
|
||||
cache.rs.insert_seq_tail_to_cell_id(i, seq_id_dst);
|
||||
} else {
|
||||
// keep only the tail cell of the source
|
||||
// assuming a copy means no rollback will be attempted afterwards
|
||||
cache.rs.remove_seq_from_cell(i, seq_id_src);
|
||||
cache.rs.remove_seq_from_cell_id(i, seq_id_src);
|
||||
if (new_head == cache.rs.size) {
|
||||
new_head = i;
|
||||
}
|
||||
@ -3009,16 +3349,41 @@ static llama_pos llama_cache_seq_cp(
|
||||
}
|
||||
|
||||
static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) {
|
||||
if (cache.rs.size > 0) {
|
||||
uint32_t new_head = cache.rs.size;
|
||||
|
||||
for (uint32_t i = 0; i < cache.rs.size; ++i) {
|
||||
llama_rs_cell & rs_cell = cache.rs.cells[i];
|
||||
if (!rs_cell.seq_nodes.empty()) {
|
||||
for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) {
|
||||
if (node_iter->seq_id != seq_id) {
|
||||
node_iter = cache.rs.remove_seq_node_from_cell(rs_cell, node_iter);
|
||||
} else {
|
||||
node_iter = std::next(node_iter);
|
||||
}
|
||||
}
|
||||
if (new_head == cache.rs.size && rs_cell.is_empty()) {
|
||||
new_head = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
if (new_head != cache.rs.size && new_head < cache.rs.head) {
|
||||
cache.rs.head = new_head;
|
||||
}
|
||||
}
|
||||
|
||||
if (cache.kv.size > 0) {
|
||||
uint32_t new_head = cache.kv.size;
|
||||
|
||||
for (uint32_t i = 0; i < cache.kv.size; ++i) {
|
||||
llama_kv_cell & kv_cell = cache.kv.cells[i];
|
||||
if (!kv_cell.has_seq_id(seq_id)) {
|
||||
if (kv_cell.pos >= 0) cache.kv.used--;
|
||||
if (kv_cell.pos >= 0) { cache.kv.used--; }
|
||||
kv_cell.pos = -1;
|
||||
kv_cell.seq_id.clear();
|
||||
if (new_head == cache.kv.size) new_head = i;
|
||||
if (new_head == cache.kv.size) { new_head = i; }
|
||||
} else {
|
||||
kv_cell.seq_id.clear();
|
||||
kv_cell.seq_id.insert(seq_id);
|
||||
@ -3052,13 +3417,12 @@ static llama_pos llama_cache_seq_add(
|
||||
while (cell_id >= 0) {
|
||||
GGML_ASSERT((uint32_t) cell_id < cache.rs.size);
|
||||
llama_rs_cell & rs_cell = cache.rs.cells[cell_id];
|
||||
int32_t i = cell_id;
|
||||
cell_id = rs_cell.prev;
|
||||
if (rs_cell.pos >= p0 && rs_cell.pos < p1) {
|
||||
rs_cell.pos += delta;
|
||||
if (rs_cell.pos < 0) {
|
||||
// NOTE: this affects the other sequences which share the cell
|
||||
cache.rs.clear_cell(i);
|
||||
cache.rs.clear_cell(rs_cell);
|
||||
// TODO: update cache.rs.head
|
||||
}
|
||||
}
|
||||
@ -6787,7 +7151,7 @@ struct llm_build_context {
|
||||
}
|
||||
|
||||
struct ggml_tensor * build_inp_s_seq() {
|
||||
lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_rs, n_tokens);
|
||||
lctx.inp_s_seq = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
cb(lctx.inp_s_seq, "inp_s_seq", -1);
|
||||
ggml_set_input(lctx.inp_s_seq);
|
||||
return lctx.inp_s_seq;
|
||||
@ -10482,26 +10846,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
|
||||
int32_t * data = (int32_t *) lctx.inp_s_seq->data;
|
||||
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const int32_t n_seq = batch.n_seq_id[j];
|
||||
GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
|
||||
|
||||
for (int i = 0; i < n_rs; ++i) {
|
||||
if (i < n_seq) {
|
||||
llama_seq_id seq_id = batch.seq_id[j][i];
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||
GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size());
|
||||
const auto & seq = rs_self.seq_tails[seq_id];
|
||||
// all sequences of this batch should already be initialized
|
||||
GGML_ASSERT(seq.tail >= 0);
|
||||
// ensure the relative cell id will be positive but not too big
|
||||
GGML_ASSERT((uint32_t) seq.tail >= rs_self.head);
|
||||
GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n);
|
||||
|
||||
data[j*n_rs + i] = seq.tail - rs_self.head;
|
||||
} else {
|
||||
data[j*n_rs + i] = -1;
|
||||
}
|
||||
}
|
||||
data[i] = seq.tail - rs_self.head;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -14874,7 +15227,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
memory_size_s += ggml_nbytes(s);
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: SSM state size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
|
||||
LLAMA_LOG_INFO("%s: SSM state size = %8.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
|
||||
ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f),
|
||||
ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f));
|
||||
@ -14891,7 +15244,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
memory_size_v += ggml_nbytes(v);
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||
LLAMA_LOG_INFO("%s: KV cache size = %8.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
@ -15458,7 +15811,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
|
||||
const size_t s_kv = ctx->cache.kv.total_size();
|
||||
const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
|
||||
const size_t s_kv_cells = ctx->cache.kv.size * s_kv_cell;
|
||||
// TODO: rs cache cells
|
||||
// FIXME: rs cache cells
|
||||
|
||||
const size_t s_total = (
|
||||
+ s_rng_size
|
||||
@ -15606,14 +15959,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: copy rs cache
|
||||
// copy kv cache
|
||||
{
|
||||
const auto & kv_self = ctx->kv_self;
|
||||
const auto & kv_self = ctx->cache.kv;
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
// NOTE: kv_size and kv_buf_size are mostly used for sanity checks
|
||||
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
|
||||
@ -15637,17 +15991,6 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
||||
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
|
||||
data_ctx->write(tmp_buf.data(), tmp_buf.size());
|
||||
|
||||
if (kv_self.recurrent) {
|
||||
// v is contiguous for recurrent models
|
||||
// TODO: use other tensors for state models than k and v
|
||||
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
|
||||
|
||||
tmp_buf.resize(v_size);
|
||||
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
|
||||
data_ctx->write(tmp_buf.data(), tmp_buf.size());
|
||||
continue;
|
||||
}
|
||||
|
||||
// v is not contiguous, copy row by row
|
||||
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
|
||||
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
|
||||
@ -15753,7 +16096,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: set rs cache too
|
||||
// FIXME: set rs cache
|
||||
// set kv cache
|
||||
{
|
||||
const auto & kv_self = ctx->cache.kv;
|
||||
|
Loading…
Reference in New Issue
Block a user