llama : session saving and reloading for hybrid models

This commit is contained in:
Francis Couture-Harpin 2024-09-01 20:31:30 -04:00
parent bc320ef66d
commit fcb889cf7f
2 changed files with 391 additions and 134 deletions

View File

@ -38,10 +38,10 @@
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 8
#define LLAMA_SESSION_VERSION 9
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 2
#define LLAMA_STATE_SEQ_VERSION 3
#ifdef __cplusplus
extern "C" {

View File

@ -19839,8 +19839,28 @@ struct llama_data_write {
}
}
void write_rs_cache_meta(const llama_rs_cache & rs_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
for (const auto & range : cell_ranges) {
for (uint32_t i = range.first; i < range.second; ++i) {
const auto & cell = rs_self.cells[i];
const llama_pos pos = cell.pos;
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_nodes.size() : 0;
write(&pos, sizeof(pos));
write(&n_seq_id, sizeof(n_seq_id));
if (n_seq_id) {
for (auto seq_node : cell.seq_nodes) {
write(&seq_node.seq_id, sizeof(seq_node.seq_id));
}
}
}
}
}
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
const struct llama_kv_cache & kv_self = ctx->kv_self;
const struct llama_kv_cache & kv_self = ctx->cache.kv;
const struct llama_hparams & hparams = ctx->model.hparams;
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
@ -19849,12 +19869,10 @@ struct llama_data_write {
write(&v_trans, sizeof(v_trans));
write(&n_layer, sizeof(n_layer));
std::vector<uint8_t> tmp_buf;
// Iterate and write all the keys first, each row is a cell
// Get whole range at a time
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
// Write key type
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
@ -19874,7 +19892,7 @@ struct llama_data_write {
if (!kv_self.v_trans) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
@ -19895,7 +19913,7 @@ struct llama_data_write {
// When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
@ -19922,43 +19940,151 @@ struct llama_data_write {
}
}
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
const struct llama_kv_cache & kv_self = ctx->kv_self;
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
void write_rs_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
const struct llama_rs_cache & rs_self = ctx->cache.rs;
const struct llama_hparams & hparams = ctx->model.hparams;
// Count the number of cells with the specified seq_id
// Find all the ranges of cells with this seq id (or all, when -1)
uint32_t cell_range_begin = kv_self.size;
for (uint32_t i = 0; i < kv_self.size; ++i) {
const auto & cell = kv_self.cells[i];
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
++cell_count;
if (cell_range_begin == kv_self.size) {
cell_range_begin = i;
const uint32_t n_layer = hparams.n_layer;
write(&n_layer, sizeof(n_layer));
// Iterate and write all recurrent states, each row is a cell
// Get whole range at a time
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_r = hparams.n_embd_r(il);
// Write type
const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type;
write(&r_type_i, sizeof(r_type_i));
// Write row size
const uint64_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r);
write(&r_size_row, sizeof(r_size_row));
// Read each range of cells of r_size length each and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * r_size_row;
write_tensor_data(rs_self.r_l[il], range.first * r_size_row, buf_size);
}
}
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_s = hparams.n_embd_s(il);
// Write type
const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type;
write(&s_type_i, sizeof(s_type_i));
// Write row size
const uint64_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s);
write(&s_size_row, sizeof(s_size_row));
// Read each range of cells of s_size length each and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * s_size_row;
write_tensor_data(rs_self.s_l[il], range.first * s_size_row, buf_size);
}
}
}
void write_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
const struct llama_kv_cache & kv_self = ctx->cache.kv;
const struct llama_rs_cache & rs_self = ctx->cache.rs;
std::vector<std::pair<uint32_t, uint32_t>> kv_cell_ranges; // ranges, from inclusive, to exclusive
std::vector<std::pair<uint32_t, uint32_t>> rs_cell_ranges; // ranges, from inclusive, to exclusive
uint32_t kv_cell_count = 0;
uint32_t rs_cell_count = 0;
// Transformer KV cache
{
// Count the number of cells with the specified seq_id
// Find all the ranges of cells with this seq id (or all, when -1)
uint32_t cell_range_begin = kv_self.size;
for (uint32_t i = 0; i < kv_self.size; ++i) {
const auto & cell = kv_self.cells[i];
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
++kv_cell_count;
if (cell_range_begin == kv_self.size) {
cell_range_begin = i;
}
} else {
if (cell_range_begin != kv_self.size) {
kv_cell_ranges.emplace_back(cell_range_begin, i);
cell_range_begin = kv_self.size;
}
}
} else {
if (cell_range_begin != kv_self.size) {
cell_ranges.emplace_back(cell_range_begin, i);
cell_range_begin = kv_self.size;
}
if (cell_range_begin != kv_self.size) {
kv_cell_ranges.emplace_back(cell_range_begin, kv_self.size);
}
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
uint32_t cell_count_check = 0;
for (const auto & range : kv_cell_ranges) {
cell_count_check += range.second - range.first;
}
GGML_ASSERT(kv_cell_count == cell_count_check);
}
// Recurrent state cache
if (seq_id == -1) {
// Find all the ranges of cells
uint32_t cell_range_begin = rs_self.size;
for (uint32_t i = 0; i < rs_self.size; ++i) {
const auto & cell = rs_self.cells[i];
if (!cell.is_empty()) {
++rs_cell_count;
if (cell_range_begin == rs_self.size) {
cell_range_begin = i;
}
} else {
if (cell_range_begin != rs_self.size) {
rs_cell_ranges.emplace_back(cell_range_begin, i);
cell_range_begin = rs_self.size;
}
}
}
if (cell_range_begin != rs_self.size) {
rs_cell_ranges.emplace_back(cell_range_begin, rs_self.size);
}
} else {
// Find the cell ranges of the specified seq_id
if ((size_t) seq_id < rs_self.seq_tails.size()) {
int32_t tail_cell_id = rs_self.seq_tails[seq_id].tail;
if (tail_cell_id >= 0) {
++rs_cell_count;
rs_cell_ranges.emplace_back(tail_cell_id, tail_cell_id + 1);
}
}
}
if (cell_range_begin != kv_self.size) {
cell_ranges.emplace_back(cell_range_begin, kv_self.size);
{
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
uint32_t cell_count_check = 0;
for (const auto & range : rs_cell_ranges) {
cell_count_check += range.second - range.first;
}
GGML_ASSERT(rs_cell_count == cell_count_check);
}
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
uint32_t cell_count_check = 0;
for (const auto & range : cell_ranges) {
cell_count_check += range.second - range.first;
write(&kv_cell_count, sizeof(kv_cell_count));
write(&rs_cell_count, sizeof(rs_cell_count));
if (seq_id == -1) {
// write metadata for both when the whole cache needs to be saved
write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id);
write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id);
} else if (kv_cell_count > 0) {
write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id);
} else {
write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id);
}
if (kv_cell_count > 0) {
write_kv_cache_data(ctx, kv_cell_ranges);
}
if (rs_cell_count > 0) {
write_rs_cache_data(ctx, rs_cell_ranges);
}
GGML_ASSERT(cell_count == cell_count_check);
write(&cell_count, sizeof(cell_count));
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
write_kv_cache_data(ctx, cell_ranges);
}
};
@ -20050,108 +20176,98 @@ struct llama_data_read {
}
}
bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
struct llama_kv_cache & kv_self = ctx->kv_self;
bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count) {
if (cell_count == 0) { return true; }
struct llama_past & cache = ctx->cache;
struct llama_kv_cache & kv_self = cache.kv;
if (dest_seq_id != -1) {
// single sequence
// whole KV cache restore
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
if (cell_count > kv_self.size) {
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
return false;
}
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
batch.n_tokens = cell_count;
batch.n_seq_tokens = cell_count;
batch.n_seqs = 1;
for (uint32_t i = 0; i < cell_count; ++i) {
llama_kv_cell & cell = kv_self.cells[i];
for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos;
uint32_t n_seq_id;
llama_pos pos;
uint32_t n_seq_id;
read_to(&pos, sizeof(pos));
read_to(&n_seq_id, sizeof(n_seq_id));
read_to(&pos, sizeof(pos));
read_to(&n_seq_id, sizeof(n_seq_id));
if (n_seq_id != 0) {
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
cell.pos = pos;
for (uint32_t j = 0; j < n_seq_id; ++j) {
llama_seq_id seq_id;
read_to(&seq_id, sizeof(seq_id));
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
return false;
}
batch.pos[i] = pos;
}
batch.n_seq_id[0] = 1;
batch.seq_id[0] = &dest_seq_id;
if (!llama_kv_cache_find_slot(kv_self, batch)) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
cell.seq_id.insert(seq_id);
}
}
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells
GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
} else {
// whole KV cache restore
kv_self.head = 0;
kv_self.used = cell_count;
if (cell_count > kv_self.size) {
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
return false;
}
return true;
}
llama_kv_cache_clear(kv_self);
bool read_rs_cache_meta(struct llama_context * ctx, uint32_t cell_count) {
if (cell_count == 0) { return true; }
struct llama_past & cache = ctx->cache;
struct llama_rs_cache & rs_self = cache.rs;
for (uint32_t i = 0; i < cell_count; ++i) {
llama_kv_cell & cell = kv_self.cells[i];
// whole RS cache restore
llama_pos pos;
uint32_t n_seq_id;
if (cell_count > rs_self.size) {
LLAMA_LOG_ERROR("%s: not enough cells in rs cache\n", __func__);
return false;
}
read_to(&pos, sizeof(pos));
read_to(&n_seq_id, sizeof(n_seq_id));
for (uint32_t i = 0; i < cell_count; ++i) {
llama_rs_cell & cell = rs_self.cells[i];
cell.pos = pos;
llama_pos pos;
uint32_t n_seq_id;
for (uint32_t j = 0; j < n_seq_id; ++j) {
llama_seq_id seq_id;
read_to(&seq_id, sizeof(seq_id));
read_to(&pos, sizeof(pos));
read_to(&n_seq_id, sizeof(n_seq_id));
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
return false;
}
cell.pos = pos;
cell.src = i;
cell.seq_id.insert(seq_id);
for (uint32_t j = 0; j < n_seq_id; ++j) {
llama_seq_id seq_id;
read_to(&seq_id, sizeof(seq_id));
if (kv_self.recurrent) {
int32_t & tail = kv_self.cells[seq_id].tail;
if (tail != -1) {
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
return false;
}
tail = i;
}
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
return false;
}
}
kv_self.head = 0;
kv_self.used = cell_count;
}
cell.insert_node(seq_id);
if (kv_self.recurrent) {
for (uint32_t i = 0; i < cell_count; ++i) {
uint32_t cell_id = kv_self.head + i;
// make sure the recurrent states will keep their restored state
kv_self.cells[cell_id].src = cell_id;
}
}
rs_self.head = 0;
rs_self.used = cell_count;
rs_self.rebuild(/* debug */ false);
return true;
}
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
if (cell_count == 0) { return true; }
const struct llama_hparams & hparams = ctx->model.hparams;
struct llama_kv_cache & kv_self = ctx->kv_self;
struct llama_kv_cache & kv_self = ctx->cache.kv;
uint32_t v_trans;
uint32_t n_layer;
read_to(&v_trans, sizeof(v_trans));
@ -20172,7 +20288,7 @@ struct llama_data_read {
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
// Read type of key
int32_t k_type_i_ref;
@ -20192,15 +20308,13 @@ struct llama_data_read {
return false;
}
if (cell_count) {
// Read and set the keys for the whole cell range
ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
}
// Read and set the keys for the whole cell range
ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
}
if (!kv_self.v_trans) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Read type of value
int32_t v_type_i_ref;
@ -20220,15 +20334,13 @@ struct llama_data_read {
return false;
}
if (cell_count) {
// Read and set the values for the whole cell range
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
}
// Read and set the values for the whole cell range
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
}
} else {
// For each layer, read the values for each cell (transposed)
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Read type of value
int32_t v_type_i_ref;
@ -20256,29 +20368,174 @@ struct llama_data_read {
return false;
}
if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
}
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
}
}
}
return true;
}
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
uint32_t cell_count;
read_to(&cell_count, sizeof(cell_count));
bool read_rs_cache_data(struct llama_context * ctx, uint32_t cell_count) {
if (cell_count == 0) { return true; }
const struct llama_hparams & hparams = ctx->model.hparams;
struct llama_rs_cache & rs_self = ctx->cache.rs;
uint32_t n_layer;
read_to(&n_layer, sizeof(n_layer));
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
if (n_layer != hparams.n_layer) {
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
return false;
}
if (cell_count > rs_self.size) {
LLAMA_LOG_ERROR("%s: not enough cells in rs cache to restore state (%u > %u)\n", __func__, cell_count, rs_self.size);
return false;
}
// For each layer, one row is one cell, read as one contiguous block
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_r = hparams.n_embd_r(il);
// Read type of key
int32_t r_type_i_ref;
read_to(&r_type_i_ref, sizeof(r_type_i_ref));
const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type;
if (r_type_i != r_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
return false;
}
// Read row size of key
uint64_t r_size_row_ref;
read_to(&r_size_row_ref, sizeof(r_size_row_ref));
const size_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r);
if (r_size_row != r_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
return false;
}
// Read and set the keys for the whole cell range
ggml_backend_tensor_set(rs_self.r_l[il], read(cell_count * r_size_row), rs_self.head * r_size_row, cell_count * r_size_row);
}
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_s = hparams.n_embd_s(il);
// Read type of key
int32_t s_type_i_ref;
read_to(&s_type_i_ref, sizeof(s_type_i_ref));
const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type;
if (s_type_i != s_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
return false;
}
// Read row size of key
uint64_t s_size_row_ref;
read_to(&s_size_row_ref, sizeof(s_size_row_ref));
const size_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s);
if (s_size_row != s_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
return false;
}
// Read and set the keys for the whole cell range
ggml_backend_tensor_set(rs_self.s_l[il], read(cell_count * s_size_row), rs_self.head * s_size_row, cell_count * s_size_row);
}
return true;
}
bool read_cache_seq_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
if (seq_id < 0 || seq_id >= llama_n_seq_max(ctx)) {
LLAMA_LOG_ERROR("%s: seq_id out of range [0, %d): %d\n", __func__, llama_n_seq_max(ctx), seq_id);
return false;
}
// single sequence
llama_past & cache = ctx->cache;
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
batch.n_tokens = cell_count;
batch.n_seq_tokens = cell_count;
batch.n_seqs = 1;
for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos;
uint32_t n_seq_id;
read_to(&pos, sizeof(pos));
read_to(&n_seq_id, sizeof(n_seq_id));
if (n_seq_id != 0) {
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
return false;
}
batch.pos[i] = pos;
}
batch.n_seq_id[0] = 1;
batch.seq_id[0] = &seq_id;
if (!llama_past_find_slot(cache, batch)) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
}
if (cache.kv.size > 0) {
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells
GGML_ASSERT(cache.kv.head + cell_count <= cache.kv.size);
GGML_ASSERT(cache.kv.cells[cache.kv.head].pos == batch.pos[0]);
GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
GGML_ASSERT(cache.kv.cells[cache.kv.head].has_seq_id(seq_id));
GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].has_seq_id(seq_id));
}
if (cache.rs.size > 0) {
GGML_ASSERT(cache.rs.head + cache.rs.n <= cache.rs.size);
GGML_ASSERT(cache.rs.n == 1);
GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].pos == batch.pos[cell_count - 1]);
GGML_ASSERT(cache.rs.cells[cache.rs.head].has_seq_id(seq_id));
GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].has_seq_id(seq_id));
// Prevent cells from being cleared
for (uint32_t i = cache.rs.head; i < cache.rs.head + cache.rs.n; ++i) {
cache.rs.cells[i].src = i;
}
}
return true;
}
void read_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
uint32_t kv_cell_count;
read_to(&kv_cell_count, sizeof(kv_cell_count));
uint32_t rs_cell_count;
read_to(&rs_cell_count, sizeof(rs_cell_count));
bool res = true;
if (seq_id == -1) {
llama_past_clear(ctx);
res = read_kv_cache_meta(ctx, kv_cell_count) && read_rs_cache_meta(ctx, rs_cell_count);
} else {
llama_past_seq_rm(ctx, seq_id, -1, -1);
// Only a single recurrent cell at most,
// because otherwise the cells can be shuffled when a slot is allocated
if (rs_cell_count > 1) {
LLAMA_LOG_ERROR("%s: too many recurrent state cells for single-sequence session\n", __func__);
res = false;
}
res = res && read_cache_seq_meta(ctx, std::max(kv_cell_count, rs_cell_count), seq_id);
}
res = res && read_kv_cache_data(ctx, kv_cell_count) && read_rs_cache_data(ctx, rs_cell_count);
if (!res) {
if (seq_id == -1) {
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
} else {
llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
llama_past_seq_rm(ctx, seq_id, -1, -1);
}
throw std::runtime_error("failed to restore kv cache");
}
@ -20433,7 +20690,7 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
data_ctx.write_logits(ctx);
data_ctx.write_embeddings(ctx);
data_ctx.write_kv_cache(ctx);
data_ctx.write_cache(ctx);
return data_ctx.get_size_written();
}
@ -20473,7 +20730,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
data_ctx.read_logits(ctx);
data_ctx.read_embeddings(ctx);
data_ctx.read_kv_cache(ctx);
data_ctx.read_cache(ctx);
return data_ctx.get_size_read();
}
@ -20569,7 +20826,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
llama_synchronize(ctx);
data_ctx.write_kv_cache(ctx, seq_id);
data_ctx.write_cache(ctx, seq_id);
return data_ctx.get_size_written();
}
@ -20592,7 +20849,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
llama_synchronize(ctx);
data_ctx.read_kv_cache(ctx, dest_seq_id);
data_ctx.read_cache(ctx, dest_seq_id);
return data_ctx.get_size_read();
}