From 0028010d01447c079f98bc33f06fca691fc99905 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Apr 2024 09:54:35 -0400 Subject: [PATCH] llama : state checkpoints for recurrent models --- ggml.c | 94 +++---- llama.cpp | 767 +++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 592 insertions(+), 269 deletions(-) diff --git a/ggml.c b/ggml.c index c9b0a6a0e..7a3f1b7a2 100644 --- a/ggml.c +++ b/ggml.c @@ -6335,19 +6335,18 @@ struct ggml_tensor * ggml_ssm_conv( GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(ggml_is_matrix(x)); GGML_ASSERT(ggml_is_matrix(c)); - GGML_ASSERT(ggml_is_matrix(sq)); + GGML_ASSERT(ggml_is_vector(sq)); GGML_ASSERT(sq->type == GGML_TYPE_I32); const int64_t d_conv = c->ne[0]; const int64_t d_inner = c->ne[1]; const int64_t n_tokens = x->ne[1]; - const int64_t n_kv = s->ne[2]; + const int64_t n_rs = s->ne[2]; GGML_ASSERT( s->ne[0] == d_conv - 1); GGML_ASSERT( s->ne[1] == d_inner); GGML_ASSERT( x->ne[0] == d_inner); - GGML_ASSERT(sq->ne[0] == n_kv); - GGML_ASSERT(sq->ne[1] == n_tokens); + GGML_ASSERT(sq->ne[0] == n_tokens); bool is_node = false; @@ -6356,8 +6355,8 @@ struct ggml_tensor * ggml_ssm_conv( is_node = true; } - // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv)); + // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs} + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs)); result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -6410,7 +6409,7 @@ struct ggml_tensor * ggml_ssm_scan( is_node = true; } - // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv} + // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs} struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; @@ -15087,9 +15086,9 @@ static void ggml_compute_forward_ssm_conv_f32( const int nc = src2->ne[0]; // d_conv const int nr = src0->ne[1]; // d_inner const int n_t = src1->ne[1]; // n_tokens - const int n_kv = src0->ne[2]; // max number of sequences in the batch + const int n_rs = src0->ne[2]; // max number of sequences in the batch - GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); + GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -15106,10 +15105,12 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { + const int32_t * sq = src3->data; // {n_tokens} + + if (n_rs > 1) { // multiple sequences means it's hard to know when it's the first time a state is read, // so copy them all over to the destination, just to be sure. 
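+ // (with a single sequence, n_rs == 1, no up-front copy is needed: the first
+ // token reads its source state directly from src0 in the i2 == 0 case below,
+ // and later tokens read back the last (d_conv - 1) columns of the state
+ // already written to dst, so the lone output state always gets written)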
- for (int i3 = 0; i3 < n_kv; ++i3) { + for (int i3 = 0; i3 < n_rs; ++i3) { float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); // can't use memcpy because of d_conv vs d_conv - 1 @@ -15123,19 +15124,19 @@ static void ggml_compute_forward_ssm_conv_f32( } for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens} - float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv} - float * s0; // {d_conv - 1, d_inner, n_kv} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} + int32_t sq_i = sq[i2]; + float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs} + float * s0; // {d_conv - 1, d_inner, n_rs} + float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} int ne0s0; - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + GGML_ASSERT(0 <= sq_i && sq_i < n_rs); // avoid needing to copy the state for the first token if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv} + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs} ne0s0 = src0->ne[0]; } else { // the source is the last (d_conv - 1) columns of the destination @@ -15153,18 +15154,6 @@ static void ggml_compute_forward_ssm_conv_f32( s[(nc - 1) + i1*nc] = x0[i1]; } - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; - } - } - // it seems a little faster when this is separate from the state shift for (int i1 = 0; i1 < ir; ++i1) { // rowwise dot product @@ -15216,7 +15205,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t nc = src0->ne[0]; // d_state const int64_t nr = src0->ne[1]; // d_inner const int64_t n_t = src1->ne[1]; // number of tokens in the batch - const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch + const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15225,6 +15214,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); // required for the dot product between s and C, and when copying the states GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); // required for per-sequence offsets for states @@ -15240,10 +15230,12 @@ static void ggml_compute_forward_ssm_scan_f32( const 
int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { + const int32_t * sq = src6->data; // {n_tokens} + + if (n_rs > 1) { // it's hard to know if the source states have already been copied // when there are multiple, so copy them already. - for (int i3 = 0; i3 < n_kv; ++i3) { + for (int i3 = 0; i3 < n_rs; ++i3) { float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); memcpy(s, s0, nc*ir*sizeof(float)); @@ -15251,21 +15243,21 @@ static void ggml_compute_forward_ssm_scan_f32( } for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv} - float * s0; - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} + int32_t sq_i = sq[i2]; + float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs} + float * s0; + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} + float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} + float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + GGML_ASSERT(0 <= sq_i && sq_i < n_rs); // avoid needing to copy the state for the first token if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs} } else { // otherwise the source is the same as the destination s0 = s; @@ -15288,18 +15280,6 @@ static void ggml_compute_forward_ssm_scan_f32( } y[i1] = sumf; } - - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; - } - } } } diff --git a/llama.cpp b/llama.cpp index 6dc310bf9..d561f80f6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2016,11 +2016,13 @@ struct llama_rs_seq_meta { // number of cells for which this seq_id is the first // (useful to know if cells in this sequence should be pruned) int32_t n_cells = 0; - // whether the tail is a cell part of multiple sequences - bool shared = false; + // changing the tail cell of a sequence can only be done at batch boundary, + // this 
guards against changing the cell when it shouldn't be; + // should be cleared when done finding a slot + bool in_ubatch = false; }; -// ring-buffer of cached recurrent state data +// ring-buffered tree of cached recurrent state data struct llama_rs_cache { bool do_copy = false; @@ -2032,8 +2034,10 @@ struct llama_rs_cache { uint32_t n = 0; // range of states used for the last slot // useful to know the minimum reserved cell count per seq_id - // only counts sequences with n_cells > 0 + // only counts sequences with n_cells > 0 AND which have a non-shared tail uint32_t n_seqs = 0; + // cells part of multiple sequences AND which have at least one tail + uint32_t n_shared_tail_cells = 0; // with state models, a cell can hold the state for more than one past token // TODO: it's probably not possible to always use contiguous cells @@ -2047,127 +2051,332 @@ struct llama_rs_cache { std::vector<struct ggml_tensor *> r_l; // rolling/shift states std::vector<struct ggml_tensor *> s_l; // ssm (recurrent) states - // returns whether or not a cell was freed - bool clear_cell(uint32_t i) { - if (i < size) { - llama_rs_cell & rs_cell = cells[i]; - if (!rs_cell.is_empty()) { - // update sequence tree links - bool first = true; - for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - // NOTE: if all next cells are the same cell, this should still work - cells[node.next_cell].prev = rs_cell.prev; + // TODO: maybe use a simpler data structure than a tree + + // Inefficient, but thorough verification and rebuilding of the rs cache + // from only the cells list with `pos` and seq_ids. + // Should not be called in a hot loop except when desperate and/or debugging. + bool rebuild(bool debug) { + bool was_valid = true; + // the source of truth is the cells list + // buffer sizes + if (size != cells.size()) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", + __func__, cells.size(), size); + } + cells.resize(size); + was_valid = false; + } + if (size != seq_tails.size()) { + if (debug) { + LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", + __func__, seq_tails.size(), size); + } + seq_tails.resize(size); + was_valid = false; + } + // cells consistency + uint32_t used_verif = 0; + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + if (cell.seq_nodes.empty()) { + if (cell.pos >= 0) { + cell.pos = -1; + was_valid = false; + } + } + if (cell.pos < 0) { + if (cell.pos != -1) { + cell.pos = -1; + was_valid = false; + } + if (!cell.seq_nodes.empty()) { + cell.seq_nodes.clear(); + was_valid = false; + } + cell.src = -1; + if (cell.prev != -1) { + cell.prev = -1; + was_valid = false; + } + } else if (!debug) { + // Assuming the cache should be actually rebuilt when not debugging + cell.src = cell_id; + } + if (!cell.seq_nodes.empty()) { + used_verif += 1; + } + } + if (used != used_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid used cell count (%u instead of %u)\n", + __func__, used, used_verif); + } + used = used_verif; + was_valid = false; + } + // tail verification + std::vector<std::pair<llama_pos, uint32_t>> seq_cells; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + seq_cells.clear(); + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + if (cell.has_seq_id(seq_id)) { + seq_cells.push_back({cell.pos, cell_id}); + } + } + // sort by pos and then by cell_id + std::sort(seq_cells.begin(), seq_cells.end());
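+ // (std::pair compares lexicographically, so seq_cells is ordered by pos with
+ // cell_id breaking ties; e.g. {{5,2},{3,7},{5,1}} sorts to {{3,7},{5,1},{5,2}},
+ // leaving the highest-pos cell of this seq_id last, which is its tail)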
+ int32_t tail = seq_cells.empty() ? -1 : seq_cells[seq_cells.size() - 1].second; + if (tail != seq.tail) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid tail for seq_id %d (%d instead of %d)\n", + __func__, seq_id, seq.tail, tail); + } + seq.tail = tail; + was_valid = false; + } + int32_t prev = -1; + for (size_t i = 0; i < seq_cells.size(); ++i) { + uint32_t cell_id = seq_cells[i].second; + llama_rs_cell & cell = cells[cell_id]; + if (cell.prev != prev) { + // TODO: relax the error when multiple cells have the same pos + if (debug) { + LLAMA_LOG_ERROR("%s: invalid prev cell for cells[%u] (%d instead of %d)\n", + __func__, cell_id, cell.prev, prev); } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - // update tail - if (node.is_tail()) { - seq.tail = rs_cell.prev; - if (seq.tail >= 0 && (uint32_t) seq.tail < size) { - llama_rs_cell & new_tail = cells[seq.tail]; - new_tail.insert_node(node.seq_id); // ensures next_cell == -1 - new_tail.tail_rc += 1; - seq.shared = new_tail.seq_nodes.size() > 1; + cell.prev = prev; + was_valid = false; + } + prev = cell_id; + } + int32_t n_cells = 0; + int32_t next = -1; + for (size_t i = seq_cells.size(); i-- > 0;) { + uint32_t cell_id = seq_cells[i].second; + llama_rs_cell & cell = cells[cell_id]; + // assuming it's always found, because how else would it end up in the list of cells for this seq_id? + auto seq_node = std::find(cell.seq_nodes.begin(), cell.seq_nodes.end(), seq_id); + if (seq_node == cell.seq_nodes.begin()) { + n_cells += 1; + } + if (seq_node->next_cell != next) { + // TODO: relax the error when multiple cells have the same pos + if (debug) { + LLAMA_LOG_ERROR("%s: invalid next cell for cells[%u] (%d instead of %d)\n", + __func__, cell_id, seq_node->next_cell, next); + } + seq_node->next_cell = next; + was_valid = false; + } + next = cell_id; + } + if (seq.n_cells != n_cells) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid n_cells for seq_id %d (%d instead of %d)\n", + __func__, seq_id, seq.n_cells, n_cells); + } + seq.n_cells = n_cells; + } + // in_ubatch should only be true when in the process of finding a slot + if (seq.in_ubatch != false) { + if (debug) { + LLAMA_LOG_ERROR("%s: in_ubatch was true while it should have been false for seq_id %d\n", + __func__, seq_id); + } + seq.in_ubatch = false; + was_valid = false; + } + } + // tail_rc + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + uint32_t tail_rc = 0; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + if (seq.tail >= 0 && (uint32_t) seq.tail == cell_id) { + tail_rc += 1; + } + } + if (cell.tail_rc != tail_rc) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid tail_rc for cells[%u] (%u instead of %u)\n", + __func__, cell_id, cell.tail_rc, tail_rc); + } + cell.tail_rc = tail_rc; + was_valid = false; + } + } + // n_seqs + uint32_t n_seqs_verif = 0; + uint32_t n_shared_tail_cells_verif = 0; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + if (seq.tail >= 0) { + llama_rs_cell & tail_cell = cells[seq.tail]; + // NOTE: could also have checked if n_cells > 0 + if (!tail_cell.seq_nodes.empty() && tail_cell.seq_nodes[0].seq_id == seq_id) { + if (tail_cell.seq_nodes.size() > 1) { + n_shared_tail_cells_verif += 1; + } else { + n_seqs_verif += 1; + } + } + } + } + if (n_seqs != n_seqs_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: wrong n_seqs (%u instead of %u)\n", + __func__, n_seqs,
n_seqs_verif); + } + n_seqs = n_seqs_verif; + was_valid = false; + } + if (n_shared_tail_cells != n_shared_tail_cells_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: wrong n_shared_tail_cells (%u instead of %u)\n", + __func__, n_shared_tail_cells, n_shared_tail_cells_verif); + } + n_shared_tail_cells = n_shared_tail_cells_verif; + was_valid = false; + } + return was_valid; + } + + // clears the cell, unlinking it from the sequence tree and updating the tails and counts of its sequences + void clear_cell(llama_rs_cell & rs_cell) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + if (!rs_cell.is_empty()) { + // update sequence tree links + bool first = true; + for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: if all next cells are the same cell, this should still work + cells[node.next_cell].prev = rs_cell.prev; + } + // next_cell of the nodes of the previous cell + if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { + llama_rs_cell & prev_cell = cells[rs_cell.prev]; + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); + // assuming the previous node is always found + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); + prev_node->next_cell = node.next_cell; + if (node.is_tail()) { + prev_cell.tail_rc += 1; + } + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + // update tail + if (node.is_tail()) { + seq.tail = rs_cell.prev; + } + // cell counts + if (first) { + seq.n_cells -= 1; + if (rs_cell.tail_rc > 0 && seq.tail < 0) { + // last tail cell + if (rs_cell.seq_nodes.size() > 1) { + n_shared_tail_cells -= 1; } else { - seq.shared = false; - } - } - // cell counts - if (first) { - seq.n_cells -= 1; - if (seq.n_cells == 0) { - GGML_ASSERT(seq.tail < 0); n_seqs -= 1; } - first = false; + } + first = false; + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; + rs_cell.seq_nodes.clear(); + used -= 1; + } + } + + // returns an iterator to the seq_node after the removed one, or the same one that was passed if it wasn't removed.
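+ // Erasing from seq_nodes invalidates iterators at and after the erased node,
+ // so callers should loop on the returned iterator; a sketch with placeholder
+ // names (cell, kept_id), mirroring the loop in llama_cache_seq_keep:
+ //   for (auto it = cell.seq_nodes.begin(); it != cell.seq_nodes.end();) {
+ //       if (it->seq_id != kept_id) { it = remove_seq_node_from_cell(cell, it); }
+ //       else { it = std::next(it); }
+ //   }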
+ std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + // TODO: assert the iterator points inside the correct vector + if (node_iter != rs_cell.seq_nodes.end()) { + if (rs_cell.seq_nodes.size() == 1) { + clear_cell(rs_cell); + return rs_cell.seq_nodes.end(); + } + // else update tree + llama_rs_seq_node node = *node_iter; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + cells[node.next_cell].prev = rs_cell.prev; + } + if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { + llama_rs_cell & prev_cell = cells[rs_cell.prev]; + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); + // assuming the previous node is always found + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); + prev_node->next_cell = node.next_cell; + if (node.is_tail()) { + prev_cell.tail_rc += 1; + } + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail < 0 && rs_cell.tail_rc == 1) { + // assuming the previous cell of a shared cell is also shared, + // (no need to update the shared tail cells count elsewhere, then) + // this was a shared tail cell, but will no longer be a tail cell + n_shared_tail_cells -= 1; + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; + } + if (node_iter == rs_cell.seq_nodes.begin()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + + // the next node is the new first one, so update its n_cells + // (will never be out-of-bounds because the size is > 1) + llama_rs_seq_node next_node = *(std::next(node_iter)); + if ((uint32_t) next_node.seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node.seq_id]; + next_seq.n_cells += 1; + // only the tail ref counts from the other seq_ids are left in tail_rc + if (rs_cell.tail_rc > 0) { + // will become a non-shared cell + if (rs_cell.seq_nodes.size() == 2) { + n_seqs += 1; + } } } else { GGML_ASSERT(false && "invalid seq_id"); } } - rs_cell.pos = -1; - rs_cell.src = -1; - rs_cell.prev = -1; - rs_cell.tail_rc = 0; - rs_cell.seq_nodes.clear(); - used -= 1; - return true; + } else { + GGML_ASSERT(false && "invalid seq_id"); } + return rs_cell.seq_nodes.erase(node_iter); - return false; + return node_iter; } - // TODO: maybe use a simpler data structure than a tree - // returns whether or not a cell was freed - bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) { + // returns whether or not the seq_id was removed + bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { if (i_cell < size && (size_t) id < size) { llama_rs_cell & rs_cell = cells[i_cell]; auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once - if (node_iter != rs_cell.seq_nodes.end()) { - if (rs_cell.seq_nodes.size() == 1) { - return clear_cell(i_cell); - } - // else update tree - llama_rs_seq_node node = *node_iter; - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - cells[node.next_cell].prev = rs_cell.prev; - } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; - if (node.is_tail()) { - seq.tail = rs_cell.prev; - if (seq.tail >= 0 && (uint32_t) seq.tail < size) { - llama_rs_cell & new_tail = cells[seq.tail]; - new_tail.insert_node(node.seq_id); // ensures
next_cell == -1 - new_tail.tail_rc += 1; - seq.shared = cells[seq.tail].seq_nodes.size() > 1; - } else { - seq.shared = false; - } - GGML_ASSERT(rs_cell.tail_rc > 0); - rs_cell.tail_rc -= 1; - } - if (node_iter == rs_cell.seq_nodes.begin()) { - // this seq_id was the first in the list - seq.n_cells -= 1; - if (seq.n_cells == 0) { - n_seqs -= 1; - } - // the next node is the new first one, so update its n_cells - // (will never be out-of-bounds because the size is > 1) - llama_rs_seq_node next_node = *(std::next(node_iter)); - if ((uint32_t) next_node.seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node.seq_id]; - next_seq.n_cells += 1; - if (next_seq.n_cells == 1) { - n_seqs += 1; - } - if (other_no_longer_shared) { - next_seq.shared = false; - } - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - } else if (other_no_longer_shared) { - llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; - if ((uint32_t) first_node.seq_id < seq_tails.size()) { - seq_tails[first_node.seq_id].shared = false; - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - } - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - rs_cell.seq_nodes.erase(node_iter); - } + return node_iter != remove_seq_node_from_cell(rs_cell, node_iter); } return false; } - bool insert_seq_tail_to_cell(uint32_t i_cell, const llama_seq_id & id) { + bool insert_seq_tail_to_cell_id(uint32_t i_cell, const llama_seq_id & id) { if (i_cell < size && (size_t) id < seq_tails.size()) { llama_rs_cell & rs_cell = cells[i_cell]; auto & seq = seq_tails[id]; @@ -2194,10 +2403,11 @@ struct llama_rs_cache { } prev_cell.tail_rc -= 1; prev_node->next_cell = i_cell; + rs_cell.prev = prev; } if (rs_cell.is_empty()) { - // only add after potential failures above - if (seq.n_cells == 0) { + // either the sequence didn't own any cells or had a shared tail cell + if (seq.n_cells == 0 || (seq.tail >= 0 && cells[seq.tail].seq_nodes.size() > 1)) { n_seqs += 1; } seq.n_cells += 1; @@ -2206,12 +2416,40 @@ struct llama_rs_cache { rs_cell.pos = 0; rs_cell.src = -1; } + used += 1; + } else if (rs_cell.seq_nodes.size() == 1 && rs_cell.tail_rc == 1) { + // don't count shared-cell tails + // FIXME: make this saner + n_seqs -= 1; + n_shared_tail_cells += 1; + } else if (rs_cell.tail_rc == 0) { + // shared cell without a tail gets a tail; + // FIXME: don't prune, in case this is used in llama_cache_seq_cp + GGML_ASSERT(false); // make sure we don't get here by accident + // prune the other sequences out of this cell + // NOTE: have to inline the removal because the state tree is partially invalid + bool first = true; + for (auto & node : rs_cell.seq_nodes) { + GGML_ASSERT(node.seq_id != id); + GGML_ASSERT(node.next_cell >= 0); + // easy removal, none of the nodes are tails + llama_rs_cell & next_cell = cells[node.next_cell]; + next_cell.prev = rs_cell.prev; + if (first) { + auto & first_seq = seq_tails[node.seq_id]; + first_seq.n_cells -= 1; + first = false; + } + } + rs_cell.seq_nodes.clear(); + } else if (rs_cell.seq_nodes.size() != rs_cell.tail_rc) { + // this is correct as long as this isn't called when trying to find a slot + // TODO: find a way to assert this } // the target cell was not already a tail of this seq_id rs_cell.insert_node(id); // next_cell == -1 by default rs_cell.tail_rc += 1; seq.tail = i_cell; - seq.shared = rs_cell.seq_nodes.size() > 1; return true; } return false; @@ -2219,33 +2457,12 @@ struct llama_rs_cache { // each seq_id should have access to at least this many cells // (to use when pruning (to avoid 
over-pruning)) - // (but this over-prunes when the system prompt doesn't take lots of cells) - // Hmm. The system prompt does not need checkpoints... - size_t min_cells_per_seq() const { - return size / (n_seqs > 0 ? n_seqs : 1); - } - - // each seq_id can have at most this many cells - // (ignoring seqs which behave as a shared prompt) - // TODO: avoid recalculating system seq_ids - // (to use when pruning (to avoid over-pruning)) - // NOTE: this also limits the shared prompt to at most half the cells - // (but the shared prompt technically needs only one cell...) - // (IDEA: keep only one cell when `llama_kv_cache_seq_cp` is called on a sequence) - size_t max_cells_per_seq() const { - int32_t n_system_seqs = 0; - int32_t n_system_cells = 0; - for (size_t i = 0; i < seq_tails.size(); ++i) { - const auto & seq = seq_tails[i]; - if (seq.tail >= 0 && (size_t) seq.tail < size) { - if (seq.shared && seq.n_cells > 0) { - n_system_seqs += 1; - n_system_cells += seq.n_cells; - } - } + size_t min_cells_per_seq(const llama_rs_seq_meta & new_seq) const { + uint32_t seqs = n_seqs; + if (new_seq.tail < 0 || new_seq.n_cells == 0) { + seqs += 1; } - int32_t n_other_seqs = n_seqs - n_system_seqs; - return (size - n_system_cells) / (n_other_seqs > 0 ? n_other_seqs : 1); + return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); } size_t total_size() const { @@ -2528,7 +2745,7 @@ struct llama_context { struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [n_rs] struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] - struct ggml_tensor * inp_s_seq; // I32 [n_rs, n_batch] + struct ggml_tensor * inp_s_seq; // I32 [n_batch] // control vectors struct llama_control_vector cvec; @@ -2657,7 +2874,7 @@ static bool llama_cache_init( return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s ctx buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -2678,54 +2895,170 @@ static bool llama_kv_cache_find_slot( if (rs_size > 0) { // For recurrent state architectures (like Mamba), // each cache cell can store the state for a whole sequence. - // TODO: real ring-buffer of states - // TODO: state chekpoints (multiple cells per sequence) // TODO: find a way to always make the rs slot contiguous - // Okay, need to find a slot. 
Everything should fit assuming the biggest seq_id < rs_size - - - llama_seq_id min = cache.rs.size - 1; - llama_seq_id max = 0; + llama_seq_id min_seq = cache.rs.size - 1; + llama_seq_id max_seq = 0; + uint32_t min_cell = cache.rs.size - 1; + uint32_t max_cell = 0; for (uint32_t i = 0; i < n_tokens; ++i) { - for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { + int32_t target_cell = -1; // ensure all the sequences of a token get the same cell + int32_t n_seq_ids = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_ids; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; - // make sure it's a valid seq_id + bool need_new_cell = false; + // Everything should fit assuming the biggest seq_id < rs_size if ((uint32_t) seq_id < rs_size) { - if (seq_id > max) { - max = seq_id; + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + if (seq_id > max_seq) { max_seq = seq_id; } + if (seq_id < min_seq) { min_seq = seq_id; } + + if (!seq.in_ubatch && target_cell >= 0) { + // never saw this seq_id before, + // but there's already a cell reserved for this token, use it + cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); + } else if (seq.tail < 0) { + need_new_cell = true; + } else { + llama_rs_cell & tail = cache.rs.cells[seq.tail]; + if (seq.in_ubatch) { + // this seq_id was already seen before in the batch + // assuming the tail cell already "has" this seq_id + tail.pos += 1; + target_cell = seq.tail; + } else { + // first time this sequence is seen, + // there's no reserved cell yet; + // if it's not the first sequence of the token, how could it even get here? + GGML_ASSERT(j == 0); + + bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; + if (has_same_seqs) { + // the tail cell of a seq_id is assumed to already be part of the seq_id, + // hence the skip of the first seq_id + for (int32_t k = 1; k < n_seq_ids; ++k) { + if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { + has_same_seqs = false; + } + } + } + + // TODO: make the checkpoint interval configurable + if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { + // a checkpoint should be saved + need_new_cell = true; + } else { + // re-use last tail + tail.pos += 1; + target_cell = seq.tail; + } + } } - if (seq_id < min) { - min = seq_id; + + if (need_new_cell && target_cell < 0) { + const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); + + uint32_t cell_id = cache.rs.size; + bool looped_once = false; + + while (true) { + if (cache.rs.head >= cache.rs.size) { + cache.rs.head = 0; + if (looped_once) { + // avoid infinite loop + // NOTE: this should not happen, but gracefully fail anyway + LLAMA_LOG_ERROR("%s: recurrent state cache seems full, but should not. 
This is a bug.\n", __func__); + return false; + } + looped_once = true; + } + cell_id = cache.rs.head; + llama_rs_cell & candidate = cache.rs.cells[cell_id]; + if (candidate.is_empty()) { break; } + if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { + if (candidate.seq_nodes.size() > 1) { + // prune out the other seq_ids, because they diverge + // TODO(maybe): handle this in insert_seq_tail_to_cell_id + // (hopefully doesn't happen too often) + for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { + if (node_iter->seq_id == seq_id) { + node_iter = std::next(node_iter); + } else { + node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); + } + } + } + // re-use the tail cell to avoid not finding anything + candidate.pos += 1; + break; + } + if (candidate.tail_rc > 0) { + // skip tails of other sequences + cache.rs.head += 1; + continue; + } + if (candidate.seq_nodes.size() > 1) { + // shared prompts are not usually backtracked, so they can be pruned + cache.rs.clear_cell(candidate); + break; + } + + // prune too-long sequences + llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; + if (seq_id_to_prune == seq_id) { + // TODO: selectively skip some cells to keep older states + cache.rs.clear_cell(candidate); + break; + } + GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); + auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; + if (seq_to_prune.n_cells > min_cells_per_seq) { + cache.rs.clear_cell(candidate); + break; + } + cache.rs.head += 1; + } + if (cell_id < cache.rs.size) { + cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); + target_cell = cell_id; + } } + + if (seq.tail >= 0) { + if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } + if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } + seq.in_ubatch = true; + } + // Assuming the tokens are in-order - if (batch.pos[i] != cache.rs.cells[seq_id].pos + 1) { + if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.rs.cells[seq_id].pos, seq_id); + __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); } - if (cache.rs.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { - cache.rs.used += 1; - } - cache.rs.cells[seq_id].pos = batch.pos[i]; - cache.rs.cells[seq_id].seq_nodes.insert(seq_id); } else { // too big seq_id - // TODO: would it be possible to resize the rs cache size instead?
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } } + cache.rs.head = target_cell + 1; + } + + for (llama_seq_id i = min_seq; i <= max_seq; ++i) { + // make sure it's cleared for next time + cache.rs.seq_tails[i].in_ubatch = false; } // allow getting the range of used cells, from head to head + n - cache.rs.head = min; - cache.rs.n = max - min + 1; + cache.rs.head = min_cell; + cache.rs.n = max_cell - min_cell + 1; // sanity check - if (max < min) { + if (max_seq < min_seq || max_cell < min_cell) { return false; } } @@ -2799,6 +3132,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } +// find how many recurrent state cells are currently in use static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { for (uint32_t i = cache.size; i > 0; --i) { const llama_rs_cell & cell = cache.cells[i - 1]; @@ -2829,12 +3163,15 @@ static void llama_cache_clear(struct llama_cache & cache) { llama_rs_cell & rs_cell = cache.rs.cells[i]; rs_cell.pos = -1; rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; rs_cell.seq_nodes.clear(); } cache.rs.do_copy = false; cache.rs.head = 0; cache.rs.used = 0; cache.rs.n_seqs = 0; + cache.rs.n_shared_tail_cells = 0; cache.rs.seq_tails.clear(); cache.rs.seq_tails.resize(cache.rs.size); } @@ -2846,8 +3183,8 @@ static llama_pos llama_cache_seq_rm( llama_pos p0, llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } llama_pos n_past = p0; @@ -2863,7 +3200,9 @@ static llama_pos llama_cache_seq_rm( for (uint32_t i = 0; i < cache.rs.size; ++i) { llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (seq_id < 0 || rs_cell.has_seq_id(seq_id)) { + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + + if (seq_id < 0 || seq_node != rs_cell.seq_nodes.end()) { if (rs_cell.pos < p0) { // move forward the new p0 further if (rs_cell.pos >= new_p0) { @@ -2879,9 +3218,9 @@ static llama_pos llama_cache_seq_rm( } } else { // (rs_cell.pos >= p0 && rs_cell.pos < p1) if (seq_id < 0) { - cache.rs.clear_cell(i); + cache.rs.clear_cell(rs_cell); } else { // (rs_cell.has_seq_id(seq_id)) - cache.rs.remove_seq_from_cell(i, seq_id); + cache.rs.remove_seq_node_from_cell(rs_cell, seq_node); } if (rs_cell.is_empty() && new_head == cache.rs.size) { new_head = i; @@ -2943,11 +3282,12 @@ static llama_pos llama_cache_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } // TODO: in practice this seems to be only used on whole sequences; - // should partial sequence copy be removed? + // should partial sequence copy support be removed? + // TODO: What if the destination sequence is not empty? 
llama_pos n_past = 0; @@ -2973,11 +3313,11 @@ static llama_pos llama_cache_seq_cp( if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) { if (i == (uint32_t) src_tail) { // need to be inserted in order, but there's only one - cache.rs.insert_seq_tail_to_cell(i, seq_id_dst); + cache.rs.insert_seq_tail_to_cell_id(i, seq_id_dst); } else { // keep only the tail cell of the source // assuming a copy means no rollback will be attempted afterwards - cache.rs.remove_seq_from_cell(i, seq_id_src); + cache.rs.remove_seq_from_cell_id(i, seq_id_src); if (new_head == cache.rs.size) { new_head = i; } @@ -3009,16 +3349,41 @@ static llama_pos llama_cache_seq_cp( } static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) { + if (cache.rs.size > 0) { + uint32_t new_head = cache.rs.size; + + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (!rs_cell.seq_nodes.empty()) { + for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { + if (node_iter->seq_id != seq_id) { + node_iter = cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + } else { + node_iter = std::next(node_iter); + } + } + if (new_head == cache.rs.size && rs_cell.is_empty()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } + } + if (cache.kv.size > 0) { uint32_t new_head = cache.kv.size; for (uint32_t i = 0; i < cache.kv.size; ++i) { llama_kv_cell & kv_cell = cache.kv.cells[i]; if (!kv_cell.has_seq_id(seq_id)) { - if (kv_cell.pos >= 0) cache.kv.used--; + if (kv_cell.pos >= 0) { cache.kv.used--; } kv_cell.pos = -1; kv_cell.seq_id.clear(); - if (new_head == cache.kv.size) new_head = i; + if (new_head == cache.kv.size) { new_head = i; } } else { kv_cell.seq_id.clear(); kv_cell.seq_id.insert(seq_id); @@ -3052,13 +3417,12 @@ static llama_pos llama_cache_seq_add( while (cell_id >= 0) { GGML_ASSERT((uint32_t) cell_id < cache.rs.size); llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - int32_t i = cell_id; cell_id = rs_cell.prev; if (rs_cell.pos >= p0 && rs_cell.pos < p1) { rs_cell.pos += delta; if (rs_cell.pos < 0) { // NOTE: this affects the other sequences which share the cell - cache.rs.clear_cell(i); + cache.rs.clear_cell(rs_cell); // TODO: update cache.rs.head } } @@ -6787,7 +7151,7 @@ struct llm_build_context { } struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_rs, n_tokens); + lctx.inp_s_seq = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(lctx.inp_s_seq, "inp_s_seq", -1); ggml_set_input(lctx.inp_s_seq); return lctx.inp_s_seq; @@ -10482,26 +10846,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); int32_t * data = (int32_t *) lctx.inp_s_seq->data; - for (int j = 0; j < n_tokens; ++j) { - const int32_t n_seq = batch.n_seq_id[j]; - GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence + for (int i = 0; i < n_tokens; ++i) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); + const auto & seq = rs_self.seq_tails[seq_id]; + // ensure the relative cell id will be positive but not too big + GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); + GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); - for (int i = 0; i < n_rs; ++i) { 
- if (i < n_seq) { - llama_seq_id seq_id = batch.seq_id[j][i]; - GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); - const auto & seq = rs_self.seq_tails[seq_id]; - // all sequences of this batch should already be initialized - GGML_ASSERT(seq.tail >= 0); - // ensure the relative cell id will be positive but not too big - GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); - GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); - - data[j*n_rs + i] = seq.tail - rs_self.head; - } else { - data[j*n_rs + i] = -1; - } - } + data[i] = seq.tail - rs_self.head; } } } @@ -14874,7 +15227,7 @@ struct llama_context * llama_new_context_with_model( memory_size_s += ggml_nbytes(s); } - LLAMA_LOG_INFO("%s: SSM state size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: SSM state size = %8.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f)); @@ -14891,7 +15244,7 @@ struct llama_context * llama_new_context_with_model( memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: KV cache size = %8.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); @@ -15458,7 +15811,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_kv = ctx->cache.kv.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); const size_t s_kv_cells = ctx->cache.kv.size * s_kv_cell; - // TODO: rs cache cells + // FIXME: rs cache cells const size_t s_total = ( + s_rng_size @@ -15606,14 +15959,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat } } + // FIXME: copy rs cache // copy kv cache { - const auto & kv_self = ctx->kv_self; + const auto & kv_self = ctx->cache.kv; const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // NOTE: kv_size and kv_buf_size are mostly used for sanity checks const uint32_t kv_head = llama_kv_cache_cell_max(kv_self); @@ -15637,17 +15991,6 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); - if (kv_self.recurrent) { - // v is contiguous for recurrent models - // TODO: use other tensors for state models than k and v - const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); - - tmp_buf.resize(v_size); - ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size()); - data_ctx->write(tmp_buf.data(), tmp_buf.size()); - continue; - } - // v is not contiguous, copy row by row const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size); @@ -15753,7 +16096,7 @@ 
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { } } - // FIXME: set rs cache too + // FIXME: set rs cache // set kv cache { const auto & kv_self = ctx->cache.kv;