context : prepare kv_cache_read/write to be moved to kv_cache

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-14 12:33:13 +02:00
parent d8b9013108
commit 057e8f5c82
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 76 additions and 78 deletions

View File

@ -928,11 +928,8 @@ struct llama_data_write {
}
}
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
const struct llama_kv_cache & kv_self = ctx->kv_self;
const struct llama_hparams & hparams = ctx->model.hparams;
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
void write_kv_cache_data(const llama_kv_cache & kv, const llama_hparams & hparams, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
const uint32_t v_trans = kv.v_trans ? 1 : 0;
const uint32_t n_layer = hparams.n_layer;
write(&v_trans, sizeof(v_trans));
@ -946,52 +943,52 @@ struct llama_data_write {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
// Write key type
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
const int32_t k_type_i = (int32_t)kv.k_l[il]->type;
write(&k_type_i, sizeof(k_type_i));
// Write row size of key
const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
const uint64_t k_size_row = ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa);
write(&k_size_row, sizeof(k_size_row));
// Read each range of cells of k_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * k_size_row;
write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
write_tensor_data(kv.k_l[il], range.first * k_size_row, buf_size);
}
}
if (!kv_self.v_trans) {
if (!kv.v_trans) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
const int32_t v_type_i = (int32_t)kv.v_l[il]->type;
write(&v_type_i, sizeof(v_type_i));
// Write row size of value
const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
const uint64_t v_size_row = ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa);
write(&v_size_row, sizeof(v_size_row));
// Read each range of cells of v_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * v_size_row;
write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
write_tensor_data(kv.v_l[il], range.first * v_size_row, buf_size);
}
}
} else {
// When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
const uint32_t kv_size = kv.size;
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
const int32_t v_type_i = (int32_t)kv.v_l[il]->type;
write(&v_type_i, sizeof(v_type_i));
// Write element size
const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
const uint32_t v_size_el = ggml_type_size(kv.v_l[il]->type);
write(&v_size_el, sizeof(v_size_el));
// Write GQA embedding size
@ -1004,37 +1001,36 @@ struct llama_data_write {
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
const size_t buf_size = range_size * v_size_el;
write_tensor_data(kv_self.v_l[il], src_offset, buf_size);
write_tensor_data(kv.v_l[il], src_offset, buf_size);
}
}
}
}
}
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
const struct llama_kv_cache & kv_self = ctx->kv_self;
void write_kv_cache(const llama_kv_cache & kv, const llama_hparams & hparams, llama_seq_id seq_id = -1) {
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
// Count the number of cells with the specified seq_id
// Find all the ranges of cells with this seq id (or all, when -1)
uint32_t cell_range_begin = kv_self.size;
for (uint32_t i = 0; i < kv_self.size; ++i) {
const auto & cell = kv_self.cells[i];
uint32_t cell_range_begin = kv.size;
for (uint32_t i = 0; i < kv.size; ++i) {
const auto & cell = kv.cells[i];
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
++cell_count;
if (cell_range_begin == kv_self.size) {
if (cell_range_begin == kv.size) {
cell_range_begin = i;
}
} else {
if (cell_range_begin != kv_self.size) {
if (cell_range_begin != kv.size) {
cell_ranges.emplace_back(cell_range_begin, i);
cell_range_begin = kv_self.size;
cell_range_begin = kv.size;
}
}
}
if (cell_range_begin != kv_self.size) {
cell_ranges.emplace_back(cell_range_begin, kv_self.size);
if (cell_range_begin != kv.size) {
cell_ranges.emplace_back(cell_range_begin, kv.size);
}
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
@ -1046,8 +1042,8 @@ struct llama_data_write {
write(&cell_count, sizeof(cell_count));
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
write_kv_cache_data(ctx, cell_ranges);
write_kv_cache_meta(kv, cell_ranges, seq_id);
write_kv_cache_data(kv, hparams, cell_ranges);
}
};
@ -1140,15 +1136,15 @@ struct llama_data_read {
}
}
bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
struct llama_kv_cache & kv_self = ctx->kv_self;
bool read_kv_cache_meta(llama_kv_cache & kv, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
if (dest_seq_id != -1) {
// single sequence
kv_self.seq_rm(dest_seq_id, -1, -1);
kv.seq_rm(dest_seq_id, -1, -1);
llama_sbatch sbatch;
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
batch.n_tokens = cell_count;
batch.n_seq_tokens = cell_count;
batch.n_seqs = 1;
@ -1157,7 +1153,7 @@ struct llama_data_read {
llama_pos pos;
uint32_t n_seq_id;
read_to(&pos, sizeof(pos));
read_to(&pos, sizeof(pos));
read_to(&n_seq_id, sizeof(n_seq_id));
if (n_seq_id != 0) {
@ -1169,30 +1165,30 @@ struct llama_data_read {
}
batch.n_seq_id[0] = 1;
batch.seq_id[0] = &dest_seq_id;
if (!kv_self.find_slot(batch)) {
if (!kv.find_slot(batch)) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
}
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells
GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
GGML_ASSERT(kv.head + cell_count <= kv.size);
GGML_ASSERT(kv.cells[kv.head].pos == batch.pos[0]);
GGML_ASSERT(kv.cells[kv.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
GGML_ASSERT(kv.cells[kv.head].has_seq_id(dest_seq_id));
GGML_ASSERT(kv.cells[kv.head + cell_count - 1].has_seq_id(dest_seq_id));
} else {
// whole KV cache restore
if (cell_count > kv_self.size) {
if (cell_count > kv.size) {
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
return false;
}
kv_self.clear();
kv.clear();
for (uint32_t i = 0; i < cell_count; ++i) {
llama_kv_cell & cell = kv_self.cells[i];
llama_kv_cell & cell = kv.cells[i];
llama_pos pos;
uint32_t n_seq_id;
@ -1206,15 +1202,18 @@ struct llama_data_read {
llama_seq_id seq_id;
read_to(&seq_id, sizeof(seq_id));
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
// TODO: llama_kv_cache should have a notion of max sequences
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
if (seq_id < 0) {
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
return false;
}
cell.seq_id.insert(seq_id);
if (kv_self.recurrent) {
int32_t & tail = kv_self.cells[seq_id].tail;
if (kv.recurrent) {
int32_t & tail = kv.cells[seq_id].tail;
if (tail != -1) {
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
return false;
@ -1224,24 +1223,22 @@ struct llama_data_read {
}
}
kv_self.head = 0;
kv_self.used = cell_count;
kv.head = 0;
kv.used = cell_count;
}
if (kv_self.recurrent) {
if (kv.recurrent) {
for (uint32_t i = 0; i < cell_count; ++i) {
uint32_t cell_id = kv_self.head + i;
uint32_t cell_id = kv.head + i;
// make sure the recurrent states will keep their restored state
kv_self.cells[cell_id].src = cell_id;
kv.cells[cell_id].src = cell_id;
}
}
return true;
}
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
const struct llama_hparams & hparams = ctx->model.hparams;
struct llama_kv_cache & kv_self = ctx->kv_self;
bool read_kv_cache_data(llama_kv_cache & kv, const llama_hparams & hparams, uint32_t cell_count) {
uint32_t v_trans;
uint32_t n_layer;
read_to(&v_trans, sizeof(v_trans));
@ -1251,11 +1248,11 @@ struct llama_data_read {
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
return false;
}
if (cell_count > kv_self.size) {
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
if (cell_count > kv.size) {
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv.size);
return false;
}
if (kv_self.v_trans != (bool) v_trans) {
if (kv.v_trans != (bool) v_trans) {
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
return false;
}
@ -1267,7 +1264,7 @@ struct llama_data_read {
// Read type of key
int32_t k_type_i_ref;
read_to(&k_type_i_ref, sizeof(k_type_i_ref));
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
const int32_t k_type_i = (int32_t)kv.k_l[il]->type;
if (k_type_i != k_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
return false;
@ -1276,7 +1273,7 @@ struct llama_data_read {
// Read row size of key
uint64_t k_size_row_ref;
read_to(&k_size_row_ref, sizeof(k_size_row_ref));
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
const size_t k_size_row = ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa);
if (k_size_row != k_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
return false;
@ -1284,18 +1281,18 @@ struct llama_data_read {
if (cell_count) {
// Read and set the keys for the whole cell range
ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
ggml_backend_tensor_set(kv.k_l[il], read(cell_count * k_size_row), kv.head * k_size_row, cell_count * k_size_row);
}
}
if (!kv_self.v_trans) {
if (!kv.v_trans) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
// Read type of value
int32_t v_type_i_ref;
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
const int32_t v_type_i = (int32_t)kv.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
@ -1304,7 +1301,7 @@ struct llama_data_read {
// Read row size of value
uint64_t v_size_row_ref;
read_to(&v_size_row_ref, sizeof(v_size_row_ref));
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
const size_t v_size_row = ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa);
if (v_size_row != v_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
return false;
@ -1312,7 +1309,7 @@ struct llama_data_read {
if (cell_count) {
// Read and set the values for the whole cell range
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
ggml_backend_tensor_set(kv.v_l[il], read(cell_count * v_size_row), kv.head * v_size_row, cell_count * v_size_row);
}
}
} else {
@ -1323,7 +1320,7 @@ struct llama_data_read {
// Read type of value
int32_t v_type_i_ref;
read_to(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
const int32_t v_type_i = (int32_t)kv.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
@ -1332,7 +1329,7 @@ struct llama_data_read {
// Read element size of value
uint32_t v_size_el_ref;
read_to(&v_size_el_ref, sizeof(v_size_el_ref));
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
const size_t v_size_el = ggml_type_size(kv.v_l[il]->type);
if (v_size_el != v_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
return false;
@ -1349,8 +1346,8 @@ struct llama_data_read {
if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
const size_t dst_offset = (kv.head + j * kv.size) * v_size_el;
ggml_backend_tensor_set(kv.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
}
}
}
@ -1358,17 +1355,17 @@ struct llama_data_read {
return true;
}
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
void read_kv_cache(llama_kv_cache & kv, const llama_hparams & hparams, llama_seq_id seq_id = -1) {
uint32_t cell_count;
read_to(&cell_count, sizeof(cell_count));
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
bool res = read_kv_cache_meta(kv, cell_count, seq_id) && read_kv_cache_data(kv, hparams, cell_count);
if (!res) {
if (seq_id == -1) {
ctx->kv_self.clear();
kv.clear();
} else {
ctx->kv_self.seq_rm(seq_id, -1, -1);
kv.seq_rm(seq_id, -1, -1);
}
throw std::runtime_error("failed to restore kv cache");
}
@ -1521,7 +1518,7 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
data_ctx.write_logits(ctx);
data_ctx.write_embeddings(ctx);
data_ctx.write_kv_cache(ctx);
data_ctx.write_kv_cache(ctx->kv_self, ctx->model.hparams);
return data_ctx.get_size_written();
}
@ -1558,7 +1555,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
data_ctx.read_logits(ctx);
data_ctx.read_embeddings(ctx);
data_ctx.read_kv_cache(ctx);
data_ctx.read_kv_cache(ctx->kv_self, ctx->model.hparams);
return data_ctx.get_size_read();
}
@ -1654,7 +1651,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
llama_synchronize(ctx);
data_ctx.write_kv_cache(ctx, seq_id);
data_ctx.write_kv_cache(ctx->kv_self, ctx->model.hparams, seq_id);
return data_ctx.get_size_written();
}
@ -1677,7 +1674,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
llama_synchronize(ctx);
data_ctx.read_kv_cache(ctx, dest_seq_id);
data_ctx.read_kv_cache(ctx->kv_self, ctx->model.hparams, dest_seq_id);
return data_ctx.get_size_read();
}

View File

@ -44,6 +44,7 @@ struct llama_kv_cache_slot_info {
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;