mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 13:52:22 +01:00
llama : session saving and reloading for hybrid models
This commit is contained in:
parent
bc320ef66d
commit
fcb889cf7f
@ -38,10 +38,10 @@
|
||||
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
||||
|
||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
#define LLAMA_SESSION_VERSION 8
|
||||
#define LLAMA_SESSION_VERSION 9
|
||||
|
||||
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||
#define LLAMA_STATE_SEQ_VERSION 2
|
||||
#define LLAMA_STATE_SEQ_VERSION 3
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
521
src/llama.cpp
521
src/llama.cpp
@ -19839,8 +19839,28 @@ struct llama_data_write {
|
||||
}
|
||||
}
|
||||
|
||||
void write_rs_cache_meta(const llama_rs_cache & rs_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
|
||||
|
||||
for (const auto & range : cell_ranges) {
|
||||
for (uint32_t i = range.first; i < range.second; ++i) {
|
||||
const auto & cell = rs_self.cells[i];
|
||||
const llama_pos pos = cell.pos;
|
||||
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_nodes.size() : 0;
|
||||
|
||||
write(&pos, sizeof(pos));
|
||||
write(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id) {
|
||||
for (auto seq_node : cell.seq_nodes) {
|
||||
write(&seq_node.seq_id, sizeof(seq_node.seq_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
|
||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
const struct llama_kv_cache & kv_self = ctx->cache.kv;
|
||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||
|
||||
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
|
||||
@ -19849,12 +19869,10 @@ struct llama_data_write {
|
||||
write(&v_trans, sizeof(v_trans));
|
||||
write(&n_layer, sizeof(n_layer));
|
||||
|
||||
std::vector<uint8_t> tmp_buf;
|
||||
|
||||
// Iterate and write all the keys first, each row is a cell
|
||||
// Get whole range at a time
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
|
||||
// Write key type
|
||||
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
||||
@ -19874,7 +19892,7 @@ struct llama_data_write {
|
||||
|
||||
if (!kv_self.v_trans) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||
@ -19895,7 +19913,7 @@ struct llama_data_write {
|
||||
// When v is transposed, we also need the element size and get the element ranges from each row
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||
@ -19922,43 +19940,151 @@ struct llama_data_write {
|
||||
}
|
||||
}
|
||||
|
||||
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
||||
uint32_t cell_count = 0;
|
||||
void write_rs_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
|
||||
const struct llama_rs_cache & rs_self = ctx->cache.rs;
|
||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||
|
||||
// Count the number of cells with the specified seq_id
|
||||
// Find all the ranges of cells with this seq id (or all, when -1)
|
||||
uint32_t cell_range_begin = kv_self.size;
|
||||
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
||||
const auto & cell = kv_self.cells[i];
|
||||
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
||||
++cell_count;
|
||||
if (cell_range_begin == kv_self.size) {
|
||||
cell_range_begin = i;
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
|
||||
write(&n_layer, sizeof(n_layer));
|
||||
|
||||
// Iterate and write all recurrent states, each row is a cell
|
||||
// Get whole range at a time
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_r = hparams.n_embd_r(il);
|
||||
|
||||
// Write type
|
||||
const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type;
|
||||
write(&r_type_i, sizeof(r_type_i));
|
||||
|
||||
// Write row size
|
||||
const uint64_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r);
|
||||
write(&r_size_row, sizeof(r_size_row));
|
||||
|
||||
// Read each range of cells of r_size length each and write out
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t buf_size = range_size * r_size_row;
|
||||
write_tensor_data(rs_self.r_l[il], range.first * r_size_row, buf_size);
|
||||
}
|
||||
}
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_s = hparams.n_embd_s(il);
|
||||
|
||||
// Write type
|
||||
const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type;
|
||||
write(&s_type_i, sizeof(s_type_i));
|
||||
|
||||
// Write row size
|
||||
const uint64_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s);
|
||||
write(&s_size_row, sizeof(s_size_row));
|
||||
|
||||
// Read each range of cells of s_size length each and write out
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t buf_size = range_size * s_size_row;
|
||||
write_tensor_data(rs_self.s_l[il], range.first * s_size_row, buf_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void write_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
||||
const struct llama_kv_cache & kv_self = ctx->cache.kv;
|
||||
const struct llama_rs_cache & rs_self = ctx->cache.rs;
|
||||
std::vector<std::pair<uint32_t, uint32_t>> kv_cell_ranges; // ranges, from inclusive, to exclusive
|
||||
std::vector<std::pair<uint32_t, uint32_t>> rs_cell_ranges; // ranges, from inclusive, to exclusive
|
||||
uint32_t kv_cell_count = 0;
|
||||
uint32_t rs_cell_count = 0;
|
||||
// Transformer KV cache
|
||||
{
|
||||
// Count the number of cells with the specified seq_id
|
||||
// Find all the ranges of cells with this seq id (or all, when -1)
|
||||
uint32_t cell_range_begin = kv_self.size;
|
||||
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
||||
const auto & cell = kv_self.cells[i];
|
||||
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
||||
++kv_cell_count;
|
||||
if (cell_range_begin == kv_self.size) {
|
||||
cell_range_begin = i;
|
||||
}
|
||||
} else {
|
||||
if (cell_range_begin != kv_self.size) {
|
||||
kv_cell_ranges.emplace_back(cell_range_begin, i);
|
||||
cell_range_begin = kv_self.size;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (cell_range_begin != kv_self.size) {
|
||||
cell_ranges.emplace_back(cell_range_begin, i);
|
||||
cell_range_begin = kv_self.size;
|
||||
}
|
||||
if (cell_range_begin != kv_self.size) {
|
||||
kv_cell_ranges.emplace_back(cell_range_begin, kv_self.size);
|
||||
}
|
||||
|
||||
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||
uint32_t cell_count_check = 0;
|
||||
for (const auto & range : kv_cell_ranges) {
|
||||
cell_count_check += range.second - range.first;
|
||||
}
|
||||
GGML_ASSERT(kv_cell_count == cell_count_check);
|
||||
}
|
||||
// Recurrent state cache
|
||||
if (seq_id == -1) {
|
||||
// Find all the ranges of cells
|
||||
uint32_t cell_range_begin = rs_self.size;
|
||||
for (uint32_t i = 0; i < rs_self.size; ++i) {
|
||||
const auto & cell = rs_self.cells[i];
|
||||
if (!cell.is_empty()) {
|
||||
++rs_cell_count;
|
||||
if (cell_range_begin == rs_self.size) {
|
||||
cell_range_begin = i;
|
||||
}
|
||||
} else {
|
||||
if (cell_range_begin != rs_self.size) {
|
||||
rs_cell_ranges.emplace_back(cell_range_begin, i);
|
||||
cell_range_begin = rs_self.size;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cell_range_begin != rs_self.size) {
|
||||
rs_cell_ranges.emplace_back(cell_range_begin, rs_self.size);
|
||||
}
|
||||
|
||||
} else {
|
||||
// Find the cell ranges of the specified seq_id
|
||||
if ((size_t) seq_id < rs_self.seq_tails.size()) {
|
||||
int32_t tail_cell_id = rs_self.seq_tails[seq_id].tail;
|
||||
if (tail_cell_id >= 0) {
|
||||
++rs_cell_count;
|
||||
rs_cell_ranges.emplace_back(tail_cell_id, tail_cell_id + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cell_range_begin != kv_self.size) {
|
||||
cell_ranges.emplace_back(cell_range_begin, kv_self.size);
|
||||
|
||||
{
|
||||
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||
uint32_t cell_count_check = 0;
|
||||
for (const auto & range : rs_cell_ranges) {
|
||||
cell_count_check += range.second - range.first;
|
||||
}
|
||||
GGML_ASSERT(rs_cell_count == cell_count_check);
|
||||
}
|
||||
|
||||
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||
uint32_t cell_count_check = 0;
|
||||
for (const auto & range : cell_ranges) {
|
||||
cell_count_check += range.second - range.first;
|
||||
write(&kv_cell_count, sizeof(kv_cell_count));
|
||||
write(&rs_cell_count, sizeof(rs_cell_count));
|
||||
|
||||
if (seq_id == -1) {
|
||||
// write metadata for both when the whole cache needs to be saved
|
||||
write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id);
|
||||
write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id);
|
||||
} else if (kv_cell_count > 0) {
|
||||
write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id);
|
||||
} else {
|
||||
write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id);
|
||||
}
|
||||
if (kv_cell_count > 0) {
|
||||
write_kv_cache_data(ctx, kv_cell_ranges);
|
||||
}
|
||||
if (rs_cell_count > 0) {
|
||||
write_rs_cache_data(ctx, rs_cell_ranges);
|
||||
}
|
||||
GGML_ASSERT(cell_count == cell_count_check);
|
||||
|
||||
write(&cell_count, sizeof(cell_count));
|
||||
|
||||
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
|
||||
write_kv_cache_data(ctx, cell_ranges);
|
||||
}
|
||||
};
|
||||
|
||||
@ -20050,108 +20176,98 @@ struct llama_data_read {
|
||||
}
|
||||
}
|
||||
|
||||
bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
|
||||
struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count) {
|
||||
if (cell_count == 0) { return true; }
|
||||
struct llama_past & cache = ctx->cache;
|
||||
struct llama_kv_cache & kv_self = cache.kv;
|
||||
|
||||
if (dest_seq_id != -1) {
|
||||
// single sequence
|
||||
// whole KV cache restore
|
||||
|
||||
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||
if (cell_count > kv_self.size) {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||
batch.n_tokens = cell_count;
|
||||
batch.n_seq_tokens = cell_count;
|
||||
batch.n_seqs = 1;
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_kv_cell & cell = kv_self.cells[i];
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
read_to(&pos, sizeof(pos));
|
||||
read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
read_to(&pos, sizeof(pos));
|
||||
read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 0) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
cell.pos = pos;
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
read_to(&seq_id, sizeof(seq_id));
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
||||
return false;
|
||||
}
|
||||
|
||||
batch.pos[i] = pos;
|
||||
}
|
||||
batch.n_seq_id[0] = 1;
|
||||
batch.seq_id[0] = &dest_seq_id;
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||
return false;
|
||||
cell.seq_id.insert(seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||
// Assume that this is one contiguous block of cells
|
||||
GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
|
||||
} else {
|
||||
// whole KV cache restore
|
||||
kv_self.head = 0;
|
||||
kv_self.used = cell_count;
|
||||
|
||||
if (cell_count > kv_self.size) {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
llama_kv_cache_clear(kv_self);
|
||||
bool read_rs_cache_meta(struct llama_context * ctx, uint32_t cell_count) {
|
||||
if (cell_count == 0) { return true; }
|
||||
struct llama_past & cache = ctx->cache;
|
||||
struct llama_rs_cache & rs_self = cache.rs;
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_kv_cell & cell = kv_self.cells[i];
|
||||
// whole RS cache restore
|
||||
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
if (cell_count > rs_self.size) {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in rs cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
read_to(&pos, sizeof(pos));
|
||||
read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_rs_cell & cell = rs_self.cells[i];
|
||||
|
||||
cell.pos = pos;
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
read_to(&seq_id, sizeof(seq_id));
|
||||
read_to(&pos, sizeof(pos));
|
||||
read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
||||
return false;
|
||||
}
|
||||
cell.pos = pos;
|
||||
cell.src = i;
|
||||
|
||||
cell.seq_id.insert(seq_id);
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
read_to(&seq_id, sizeof(seq_id));
|
||||
|
||||
if (kv_self.recurrent) {
|
||||
int32_t & tail = kv_self.cells[seq_id].tail;
|
||||
if (tail != -1) {
|
||||
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
|
||||
return false;
|
||||
}
|
||||
tail = i;
|
||||
}
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
kv_self.head = 0;
|
||||
kv_self.used = cell_count;
|
||||
}
|
||||
cell.insert_node(seq_id);
|
||||
|
||||
if (kv_self.recurrent) {
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
uint32_t cell_id = kv_self.head + i;
|
||||
// make sure the recurrent states will keep their restored state
|
||||
kv_self.cells[cell_id].src = cell_id;
|
||||
}
|
||||
}
|
||||
|
||||
rs_self.head = 0;
|
||||
rs_self.used = cell_count;
|
||||
|
||||
rs_self.rebuild(/* debug */ false);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
|
||||
if (cell_count == 0) { return true; }
|
||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||
struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
struct llama_kv_cache & kv_self = ctx->cache.kv;
|
||||
uint32_t v_trans;
|
||||
uint32_t n_layer;
|
||||
read_to(&v_trans, sizeof(v_trans));
|
||||
@ -20172,7 +20288,7 @@ struct llama_data_read {
|
||||
|
||||
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
|
||||
// Read type of key
|
||||
int32_t k_type_i_ref;
|
||||
@ -20192,15 +20308,13 @@ struct llama_data_read {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
|
||||
}
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
|
||||
}
|
||||
|
||||
if (!kv_self.v_trans) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
@ -20220,15 +20334,13 @@ struct llama_data_read {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the values for the whole cell range
|
||||
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
|
||||
}
|
||||
// Read and set the values for the whole cell range
|
||||
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
|
||||
}
|
||||
} else {
|
||||
// For each layer, read the values for each cell (transposed)
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
@ -20256,29 +20368,174 @@ struct llama_data_read {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
|
||||
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
}
|
||||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
|
||||
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
||||
uint32_t cell_count;
|
||||
read_to(&cell_count, sizeof(cell_count));
|
||||
bool read_rs_cache_data(struct llama_context * ctx, uint32_t cell_count) {
|
||||
if (cell_count == 0) { return true; }
|
||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||
struct llama_rs_cache & rs_self = ctx->cache.rs;
|
||||
uint32_t n_layer;
|
||||
read_to(&n_layer, sizeof(n_layer));
|
||||
|
||||
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
|
||||
if (n_layer != hparams.n_layer) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
|
||||
return false;
|
||||
}
|
||||
if (cell_count > rs_self.size) {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in rs cache to restore state (%u > %u)\n", __func__, cell_count, rs_self.size);
|
||||
return false;
|
||||
}
|
||||
|
||||
// For each layer, one row is one cell, read as one contiguous block
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_r = hparams.n_embd_r(il);
|
||||
|
||||
// Read type of key
|
||||
int32_t r_type_i_ref;
|
||||
read_to(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||
const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type;
|
||||
if (r_type_i != r_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read row size of key
|
||||
uint64_t r_size_row_ref;
|
||||
read_to(&r_size_row_ref, sizeof(r_size_row_ref));
|
||||
const size_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r);
|
||||
if (r_size_row != r_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(rs_self.r_l[il], read(cell_count * r_size_row), rs_self.head * r_size_row, cell_count * r_size_row);
|
||||
}
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_s = hparams.n_embd_s(il);
|
||||
|
||||
// Read type of key
|
||||
int32_t s_type_i_ref;
|
||||
read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type;
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read row size of key
|
||||
uint64_t s_size_row_ref;
|
||||
read_to(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||
const size_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s);
|
||||
if (s_size_row != s_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(rs_self.s_l[il], read(cell_count * s_size_row), rs_self.head * s_size_row, cell_count * s_size_row);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool read_cache_seq_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
|
||||
|
||||
if (seq_id < 0 || seq_id >= llama_n_seq_max(ctx)) {
|
||||
LLAMA_LOG_ERROR("%s: seq_id out of range [0, %d): %d\n", __func__, llama_n_seq_max(ctx), seq_id);
|
||||
return false;
|
||||
}
|
||||
|
||||
// single sequence
|
||||
|
||||
llama_past & cache = ctx->cache;
|
||||
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||
batch.n_tokens = cell_count;
|
||||
batch.n_seq_tokens = cell_count;
|
||||
batch.n_seqs = 1;
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
read_to(&pos, sizeof(pos));
|
||||
read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 0) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
batch.pos[i] = pos;
|
||||
}
|
||||
batch.n_seq_id[0] = 1;
|
||||
batch.seq_id[0] = &seq_id;
|
||||
if (!llama_past_find_slot(cache, batch)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cache.kv.size > 0) {
|
||||
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||
// Assume that this is one contiguous block of cells
|
||||
GGML_ASSERT(cache.kv.head + cell_count <= cache.kv.size);
|
||||
GGML_ASSERT(cache.kv.cells[cache.kv.head].pos == batch.pos[0]);
|
||||
GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cache.kv.cells[cache.kv.head].has_seq_id(seq_id));
|
||||
GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].has_seq_id(seq_id));
|
||||
}
|
||||
if (cache.rs.size > 0) {
|
||||
GGML_ASSERT(cache.rs.head + cache.rs.n <= cache.rs.size);
|
||||
GGML_ASSERT(cache.rs.n == 1);
|
||||
GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].pos == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cache.rs.cells[cache.rs.head].has_seq_id(seq_id));
|
||||
GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].has_seq_id(seq_id));
|
||||
// Prevent cells from being cleared
|
||||
for (uint32_t i = cache.rs.head; i < cache.rs.head + cache.rs.n; ++i) {
|
||||
cache.rs.cells[i].src = i;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void read_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
||||
uint32_t kv_cell_count;
|
||||
read_to(&kv_cell_count, sizeof(kv_cell_count));
|
||||
uint32_t rs_cell_count;
|
||||
read_to(&rs_cell_count, sizeof(rs_cell_count));
|
||||
|
||||
bool res = true;
|
||||
|
||||
if (seq_id == -1) {
|
||||
llama_past_clear(ctx);
|
||||
res = read_kv_cache_meta(ctx, kv_cell_count) && read_rs_cache_meta(ctx, rs_cell_count);
|
||||
} else {
|
||||
llama_past_seq_rm(ctx, seq_id, -1, -1);
|
||||
// Only a single recurrent cell at most,
|
||||
// because otherwise the cells can be shuffled when a slot is allocated
|
||||
if (rs_cell_count > 1) {
|
||||
LLAMA_LOG_ERROR("%s: too many recurrent state cells for single-sequence session\n", __func__);
|
||||
res = false;
|
||||
}
|
||||
res = res && read_cache_seq_meta(ctx, std::max(kv_cell_count, rs_cell_count), seq_id);
|
||||
}
|
||||
|
||||
res = res && read_kv_cache_data(ctx, kv_cell_count) && read_rs_cache_data(ctx, rs_cell_count);
|
||||
|
||||
if (!res) {
|
||||
if (seq_id == -1) {
|
||||
llama_kv_cache_clear(ctx);
|
||||
llama_past_clear(ctx);
|
||||
} else {
|
||||
llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
|
||||
llama_past_seq_rm(ctx, seq_id, -1, -1);
|
||||
}
|
||||
throw std::runtime_error("failed to restore kv cache");
|
||||
}
|
||||
@ -20433,7 +20690,7 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
|
||||
data_ctx.write_logits(ctx);
|
||||
data_ctx.write_embeddings(ctx);
|
||||
|
||||
data_ctx.write_kv_cache(ctx);
|
||||
data_ctx.write_cache(ctx);
|
||||
|
||||
return data_ctx.get_size_written();
|
||||
}
|
||||
@ -20473,7 +20730,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
|
||||
data_ctx.read_logits(ctx);
|
||||
data_ctx.read_embeddings(ctx);
|
||||
|
||||
data_ctx.read_kv_cache(ctx);
|
||||
data_ctx.read_cache(ctx);
|
||||
|
||||
return data_ctx.get_size_read();
|
||||
}
|
||||
@ -20569,7 +20826,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
|
||||
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
data_ctx.write_kv_cache(ctx, seq_id);
|
||||
data_ctx.write_cache(ctx, seq_id);
|
||||
|
||||
return data_ctx.get_size_written();
|
||||
}
|
||||
@ -20592,7 +20849,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
|
||||
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
data_ctx.read_kv_cache(ctx, dest_seq_id);
|
||||
data_ctx.read_cache(ctx, dest_seq_id);
|
||||
|
||||
return data_ctx.get_size_read();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user