llama : fix save/load state

commit b59ddf945e
parent 29ab5a0ed1
Author: Georgi Gerganov
Date:   2024-07-04 15:55:23 +03:00


@@ -18757,8 +18757,6 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
         const auto & hparams = ctx->model.hparams;
         const uint32_t n_layer = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
         // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
         const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
@@ -18778,6 +18776,9 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
         std::vector<uint8_t> tmp_buf;
         for (int il = 0; il < (int) n_layer; ++il) {
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
             tmp_buf.resize(k_size);
@@ -18910,8 +18911,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
         const auto & hparams = ctx->model.hparams;
         const uint32_t n_layer = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
         size_t kv_buf_size;
         uint32_t kv_head;
@@ -18943,6 +18942,9 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
             GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
 
             for (int il = 0; il < (int) n_layer; ++il) {
+                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
                 const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
                 ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
@@ -19105,8 +19107,6 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     const auto & hparams = ctx->model.hparams;
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     for (uint32_t i = 0; i < kv_self.size; ++i) {
         const auto & cell = kv_self.cells[i];
@@ -19117,6 +19117,9 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     }
 
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
         // types of keys and values
         s_cell_data_size += sizeof(int32_t) * 2;
         // k_size_row and v_size_el values of layer
@@ -19191,14 +19194,15 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     const auto & hparams = ctx->model.hparams;
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     // Write the layer count
     data_ctx.write(&n_layer, sizeof(n_layer));
 
-    // Write n_embd_v_gqa
-    data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+    // Write n_embd_v_gqa (reference value)
+    {
+        const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
+        data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+    }
 
     // Iterate the ranges and write all the pos (this is the token position in the prompt)
     for (const auto & range : cell_ranges) {
@@ -19212,6 +19216,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     // Get whole range at a time
     std::vector<uint8_t> tmp_buf;
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
         // Write key type
         const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
         data_ctx.write(&k_type_i, sizeof(k_type_i));
@@ -19232,6 +19238,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -19252,6 +19260,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
         // For the values, they are transposed, so we also need the element size and get the element ranges from each row
         const uint32_t kv_size = kv_self.size;
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -19320,14 +19330,14 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     // Sanity check model compatibility
     const auto & hparams = ctx->model.hparams;
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     if (n_layer != n_layer_ref) {
         LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
         return 0;
     }
 
-    if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref);
+    if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
+        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
         return 0;
     }
@@ -19367,6 +19377,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
         // Read type of key
         int32_t k_type_i_ref;
         memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
@@ -19399,6 +19411,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Read type of value
             int32_t v_type_i_ref;
             memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
@@ -19430,6 +19444,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     } else {
         // For each layer, read the values for each cell (transposed)
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Read type of value
             int32_t v_type_i_ref;
             memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
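
Every hunk above follows the same pattern: n_embd_k_gqa and n_embd_v_gqa are no longer computed once per function but once per layer, because hparams.n_embd_k_gqa(il) and hparams.n_embd_v_gqa(il) can differ between layers. Below is a minimal, self-contained sketch (not part of the commit) of why this matters when sizing the serialized KV cache; sketch_hparams, its example widths, and row_size_f16 are hypothetical stand-ins for llama_hparams and ggml_row_size().

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the llama_hparams accessors used in the diff:
// the per-layer K width may differ between layers, which is what the fix accounts for.
struct sketch_hparams {
    uint32_t n_embd_k_gqa(int il = 0) const { return il % 2 == 0 ? 1024 : 256; }
    uint32_t n_embd_k_s() const { return 0; } // extra recurrent-state width, 0 for pure attention
};

// Stand-in for ggml_row_size(type, ne), assuming an F16 K cache (2 bytes per element).
static size_t row_size_f16(size_t ne) { return ne * 2; }

int main() {
    const sketch_hparams hparams;
    const uint32_t n_layer = 4;
    const uint32_t kv_head = 8; // number of KV cells being serialized

    // Pre-fix behavior: a single width, computed once (defaults to layer 0).
    const uint32_t n_embd_k_gqa_global = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();

    size_t total_global    = 0;
    size_t total_per_layer = 0;

    for (int il = 0; il < (int) n_layer; ++il) {
        // Post-fix behavior: recompute the width for every layer, as in the patched loops.
        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

        total_global    += row_size_f16((size_t) n_embd_k_gqa_global*kv_head);
        total_per_layer += row_size_f16((size_t) n_embd_k_gqa*kv_head);
    }

    // With non-uniform layers the totals disagree (65536 vs 40960 bytes here), so a
    // state file written with the global width would not round-trip through set_data.
    printf("global: %zu bytes, per-layer: %zu bytes\n", total_global, total_per_layer);
    return 0;
}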