From b59ddf945e7e8bbe69628e0b6063e8e3fb279329 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 4 Jul 2024 15:55:23 +0300
Subject: [PATCH] llama : fix save/load state

---
 src/llama.cpp | 44 ++++++++++++++++++++++++++++++--------------
 1 file changed, 30 insertions(+), 14 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index f68c912e6..9ff321b3f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18757,8 +18757,6 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
         const auto & hparams = ctx->model.hparams;
 
         const uint32_t n_layer = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
         // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
         const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
@@ -18778,6 +18776,9 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
 
             std::vector<uint8_t> tmp_buf;
             for (int il = 0; il < (int) n_layer; ++il) {
+                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
                 const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
 
                 tmp_buf.resize(k_size);
@@ -18910,8 +18911,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
         const auto & hparams = ctx->model.hparams;
 
         const uint32_t n_layer = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
         size_t kv_buf_size;
         uint32_t kv_head;
@@ -18943,6 +18942,9 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
             GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
 
             for (int il = 0; il < (int) n_layer; ++il) {
+                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
                 const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
 
                 ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
@@ -19105,8 +19107,6 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     const auto & hparams = ctx->model.hparams;
 
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     for (uint32_t i = 0; i < kv_self.size; ++i) {
         const auto & cell = kv_self.cells[i];
@@ -19117,6 +19117,9 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     }
 
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
         // types of keys and values
         s_cell_data_size += sizeof(int32_t) * 2;
         // k_size_row and v_size_el values of layer
@@ -19191,14 +19194,15 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     const auto & hparams = ctx->model.hparams;
 
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     // Write the layer count
     data_ctx.write(&n_layer, sizeof(n_layer));
 
-    // Write n_embd_v_gqa
-    data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+    // Write n_embd_v_gqa (reference value)
+    {
+        const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
+        data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+    }
 
     // Iterate the ranges and write all the pos (this is the token position in the prompt)
     for (const auto & range : cell_ranges) {
@@ -19212,6 +19216,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     // Get whole range at a time
     std::vector<uint8_t> tmp_buf;
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
         // Write key type
         const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
         data_ctx.write(&k_type_i, sizeof(k_type_i));
@@ -19232,6 +19238,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -19252,6 +19260,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
         // For the values, they are transposed, so we also need the element size and get the element ranges from each row
         const uint32_t kv_size = kv_self.size;
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -19320,14 +19330,14 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     // Sanity check model compatibility
     const auto & hparams = ctx->model.hparams;
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
+
     if (n_layer != n_layer_ref) {
         LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
         return 0;
     }
-    if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref);
+
+    if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
+        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
         return 0;
     }
 
@@ -19367,6 +19377,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
        // Read type of key
        int32_t k_type_i_ref;
        memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
@@ -19399,6 +19411,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Read type of value
             int32_t v_type_i_ref;
             memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
@@ -19430,6 +19444,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     } else {
         // For each layer, read the values for each cell (transposed)
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Read type of value
             int32_t v_type_i_ref;
             memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));