mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 11:23:56 +01:00
llama : fix save/load state
This commit is contained in:
parent
29ab5a0ed1
commit
b59ddf945e
@ -18757,8 +18757,6 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
|
|||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
|
||||||
|
|
||||||
// NOTE: kv_size and kv_buf_size are mostly used for sanity checks
|
// NOTE: kv_size and kv_buf_size are mostly used for sanity checks
|
||||||
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
|
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
|
||||||
@ -18778,6 +18776,9 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
|
|||||||
|
|
||||||
std::vector<uint8_t> tmp_buf;
|
std::vector<uint8_t> tmp_buf;
|
||||||
for (int il = 0; il < (int) n_layer; ++il) {
|
for (int il = 0; il < (int) n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||||
|
|
||||||
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
|
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
|
||||||
|
|
||||||
tmp_buf.resize(k_size);
|
tmp_buf.resize(k_size);
|
||||||
@ -18910,8 +18911,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
|||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
|
||||||
|
|
||||||
size_t kv_buf_size;
|
size_t kv_buf_size;
|
||||||
uint32_t kv_head;
|
uint32_t kv_head;
|
||||||
@ -18943,6 +18942,9 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
|||||||
GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
|
GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
|
||||||
|
|
||||||
for (int il = 0; il < (int) n_layer; ++il) {
|
for (int il = 0; il < (int) n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||||
|
|
||||||
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
|
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
|
||||||
|
|
||||||
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
|
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
|
||||||
@ -19105,8 +19107,6 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
|
|||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
||||||
const auto & cell = kv_self.cells[i];
|
const auto & cell = kv_self.cells[i];
|
||||||
@ -19117,6 +19117,9 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int il = 0; il < (int)n_layer; ++il) {
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||||
|
|
||||||
// types of keys and values
|
// types of keys and values
|
||||||
s_cell_data_size += sizeof(int32_t) * 2;
|
s_cell_data_size += sizeof(int32_t) * 2;
|
||||||
// k_size_row and v_size_el values of layer
|
// k_size_row and v_size_el values of layer
|
||||||
@ -19191,14 +19194,15 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
|||||||
|
|
||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
|
||||||
|
|
||||||
// Write the layer count
|
// Write the layer count
|
||||||
data_ctx.write(&n_layer, sizeof(n_layer));
|
data_ctx.write(&n_layer, sizeof(n_layer));
|
||||||
|
|
||||||
// Write n_embd_v_gqa
|
// Write n_embd_v_gqa (reference value)
|
||||||
data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
{
|
||||||
|
const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
|
||||||
|
data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
||||||
|
}
|
||||||
|
|
||||||
// Iterate the ranges and write all the pos (this is the token position in the prompt)
|
// Iterate the ranges and write all the pos (this is the token position in the prompt)
|
||||||
for (const auto & range : cell_ranges) {
|
for (const auto & range : cell_ranges) {
|
||||||
@ -19212,6 +19216,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
|||||||
// Get whole range at a time
|
// Get whole range at a time
|
||||||
std::vector<uint8_t> tmp_buf;
|
std::vector<uint8_t> tmp_buf;
|
||||||
for (int il = 0; il < (int)n_layer; ++il) {
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||||
|
|
||||||
// Write key type
|
// Write key type
|
||||||
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
||||||
data_ctx.write(&k_type_i, sizeof(k_type_i));
|
data_ctx.write(&k_type_i, sizeof(k_type_i));
|
||||||
@ -19232,6 +19238,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
|||||||
// TODO: simplify, reduce copy-paste
|
// TODO: simplify, reduce copy-paste
|
||||||
if (!kv_self.v_trans) {
|
if (!kv_self.v_trans) {
|
||||||
for (int il = 0; il < (int)n_layer; ++il) {
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||||
|
|
||||||
// Write value type
|
// Write value type
|
||||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||||
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
||||||
@ -19252,6 +19260,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
|||||||
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
|
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
|
||||||
const uint32_t kv_size = kv_self.size;
|
const uint32_t kv_size = kv_self.size;
|
||||||
for (int il = 0; il < (int)n_layer; ++il) {
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||||
|
|
||||||
// Write value type
|
// Write value type
|
||||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||||
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
||||||
@ -19320,14 +19330,14 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
|||||||
// Sanity check model compatibility
|
// Sanity check model compatibility
|
||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
|
||||||
if (n_layer != n_layer_ref) {
|
if (n_layer != n_layer_ref) {
|
||||||
LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
|
LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
|
||||||
LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref);
|
if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -19367,6 +19377,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
|||||||
|
|
||||||
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
|
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
|
||||||
for (int il = 0; il < (int)n_layer; ++il) {
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||||
|
|
||||||
// Read type of key
|
// Read type of key
|
||||||
int32_t k_type_i_ref;
|
int32_t k_type_i_ref;
|
||||||
memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
|
memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
|
||||||
@ -19399,6 +19411,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
|||||||
// TODO: simplify, reduce copy-paste
|
// TODO: simplify, reduce copy-paste
|
||||||
if (!kv_self.v_trans) {
|
if (!kv_self.v_trans) {
|
||||||
for (int il = 0; il < (int)n_layer; ++il) {
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||||
|
|
||||||
// Read type of value
|
// Read type of value
|
||||||
int32_t v_type_i_ref;
|
int32_t v_type_i_ref;
|
||||||
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
||||||
@ -19430,6 +19444,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
|||||||
} else {
|
} else {
|
||||||
// For each layer, read the values for each cell (transposed)
|
// For each layer, read the values for each cell (transposed)
|
||||||
for (int il = 0; il < (int)n_layer; ++il) {
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||||
|
|
||||||
// Read type of value
|
// Read type of value
|
||||||
int32_t v_type_i_ref;
|
int32_t v_type_i_ref;
|
||||||
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
||||||
|
Loading…
Reference in New Issue
Block a user