diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 48d801110..ef952e2bd 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -45,13 +45,13 @@ int main(int argc, char ** argv) { // save state (rng, logits, embedding and kv_cache) to file { std::vector state_mem(llama_get_state_size(ctx)); + const size_t written = llama_copy_state_data(ctx, state_mem.data()); - { - FILE *fp_write = fopen("dump_state.bin", "wb"); - llama_copy_state_data(ctx, state_mem.data()); // could also copy directly to memory mapped file - fwrite(state_mem.data(), 1, state_mem.size(), fp_write); - fclose(fp_write); - } + FILE *fp_write = fopen("dump_state.bin", "wb"); + fwrite(state_mem.data(), 1, written, fp_write); + fclose(fp_write); + + fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size()); } // save state (last tokens) @@ -100,18 +100,17 @@ int main(int argc, char ** argv) { std::vector state_mem(llama_get_state_size(ctx2)); FILE * fp_read = fopen("dump_state.bin", "rb"); + const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); + fclose(fp_read); - const size_t ret = fread(state_mem.data(), 1, state_mem.size(), fp_read); - if (ret != state_mem.size()) { + if (read != llama_set_state_data(ctx2, state_mem.data())) { fprintf(stderr, "\n%s : failed to read state\n", __func__); llama_free(ctx2); llama_free_model(model); return 1; } - llama_set_state_data(ctx2, state_mem.data()); - - fclose(fp_read); + fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); } // restore state (last tokens) diff --git a/llama.cpp b/llama.cpp index 1d2eb569f..275456088 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9379,12 +9379,8 @@ struct llama_context * llama_new_context_with_model( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } - // resized during inference - if (params.logits_all) { - ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab); - } else { - ctx->logits.reserve(hparams.n_vocab); - } + // resized during inference, reserve maximum + ctx->logits.reserve(hparams.n_vocab*cparams.n_batch); if (params.embedding){ ctx->embedding.resize(hparams.n_embd); @@ -9731,8 +9727,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) { // for reference, std::mt19937(1337) serializes to 6701 bytes. const size_t s_rng_size = sizeof(size_t); const size_t s_rng = LLAMA_MAX_RNG_STATE; - const size_t s_logits_capacity = sizeof(size_t); const size_t s_logits_size = sizeof(size_t); + // assume worst case for logits although only currently set ones are serialized const size_t s_logits = ctx->logits.capacity() * sizeof(float); const size_t s_embedding_size = sizeof(size_t); const size_t s_embedding = ctx->embedding.size() * sizeof(float); @@ -9743,7 +9739,6 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_total = ( + s_rng_size + s_rng - + s_logits_capacity + s_logits_size + s_logits + s_embedding_size @@ -9812,37 +9807,27 @@ struct llama_data_file_context : llama_data_context { static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { // copy rng { - std::stringstream rng_ss; + std::ostringstream rng_ss; rng_ss << ctx->rng; - const size_t rng_size = rng_ss.str().size(); - char rng_buf[LLAMA_MAX_RNG_STATE]; + const std::string & rng_str = rng_ss.str(); + const size_t rng_size = rng_str.size(); - memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); - memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); + GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); - data_ctx->write(&rng_size, sizeof(rng_size)); - data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE); + data_ctx->write(&rng_size, sizeof(rng_size)); + data_ctx->write(rng_str.data(), rng_size); } // copy logits { - const size_t logits_cap = ctx->logits.capacity(); const size_t logits_size = ctx->logits.size(); - data_ctx->write(&logits_cap, sizeof(logits_cap)); data_ctx->write(&logits_size, sizeof(logits_size)); if (logits_size) { data_ctx->write(ctx->logits.data(), logits_size * sizeof(float)); } - - // If there is a gap between the size and the capacity, write padding - size_t padding_size = (logits_cap - logits_size) * sizeof(float); - if (padding_size > 0) { - std::vector padding(padding_size, 0); // Create a buffer filled with zeros - data_ctx->write(padding.data(), padding_size); - } } // copy embeddings @@ -9925,13 +9910,13 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { // set rng { size_t rng_size; - char rng_buf[LLAMA_MAX_RNG_STATE]; + memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); - memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); - memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE; + GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); - std::stringstream rng_ss; - rng_ss.str(std::string(&rng_buf[0], rng_size)); + std::string rng_str((char *)inp, rng_size); inp += rng_size; + + std::istringstream rng_ss(rng_str); rng_ss >> ctx->rng; GGML_ASSERT(!rng_ss.fail()); @@ -9939,20 +9924,18 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { // set logits { - size_t logits_cap; size_t logits_size; - memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap); memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size); - GGML_ASSERT(ctx->logits.capacity() == logits_cap); + GGML_ASSERT(ctx->logits.capacity() >= logits_size); if (logits_size) { ctx->logits.resize(logits_size); - memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); - } - inp += logits_cap * sizeof(float); + memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); + inp += logits_size * sizeof(float); + } } // set embeddings diff --git a/llama.h b/llama.h index 689e12d7c..01d6fafaa 100644 --- a/llama.h +++ b/llama.h @@ -43,7 +43,7 @@ #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 3 +#define LLAMA_SESSION_VERSION 4 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) // Defined when llama.cpp is compiled with support for offloading model layers to GPU.