diff --git a/llama-util.h b/llama-util.h index 042ebe43c..3fc03ce28 100644 --- a/llama-util.h +++ b/llama-util.h @@ -149,6 +149,46 @@ struct llama_file { } }; +// llama_context_data +struct llama_data_context { + virtual void write(const void * src, size_t size) = 0; + virtual size_t get_size_written() = 0; + virtual ~llama_data_context() = default; +}; + +struct llama_data_buffer_context : llama_data_context { + uint8_t* ptr; + size_t size_written = 0; + + llama_data_buffer_context(uint8_t * p) : ptr(p) {} + + void write(const void * src, size_t size) override { + memcpy(ptr, src, size); + ptr += size; + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_file_context : llama_data_context { + llama_file* file; + size_t size_written = 0; + + llama_data_file_context(llama_file * f) : file(f) {} + + void write(const void * src, size_t size) override { + file->write_raw(src, size); + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + #if defined(_WIN32) static std::string llama_format_win_err(DWORD err) { LPSTR buf; diff --git a/llama.cpp b/llama.cpp index d427054dd..839739870 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3743,10 +3743,20 @@ size_t llama_get_state_size(const struct llama_context * ctx) { return s_total; } -// Copies the state to the specified destination address -size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { - uint8_t * out = dst; - +/** copy state data into either a buffer or file depending on the passed in context + * + * file context: + * llama_file file("/path", "wb"); + * llama_data_file_context data_ctx(&file); + * llama_copy_state_data(ctx, &data_ctx); + * + * buffer context: + * std::vector buf(max_size, 0); + * llama_data_buffer_context data_ctx(&buf.data()); + * llama_copy_state_data(ctx, &data_ctx); + * +*/ +void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { // copy rng { std::stringstream rng_ss; @@ -3758,8 +3768,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); - memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size); - memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE; + data_ctx->write(&rng_size, sizeof(rng_size)); + data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE); } // copy logits @@ -3767,25 +3777,29 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { const size_t logits_cap = ctx->logits.capacity(); const size_t logits_size = ctx->logits.size(); - memcpy(out, &logits_cap, sizeof(logits_cap)); out += sizeof(logits_cap); - memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size); + data_ctx->write(&logits_cap, sizeof(logits_cap)); + data_ctx->write(&logits_size, sizeof(logits_size)); if (logits_size) { - memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); + data_ctx->write(ctx->logits.data(), logits_size * sizeof(float)); } - out += logits_cap * sizeof(float); + // If there is a gap between the size and the capacity, write padding + size_t padding_size = (logits_cap - logits_size) * sizeof(float); + if (padding_size > 0) { + std::vector padding(padding_size, 0); // Create a buffer filled with zeros + data_ctx->write(padding.data(), padding_size); + } } // copy embeddings { const size_t embedding_size = ctx->embedding.size(); - memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size); + data_ctx->write(&embedding_size, sizeof(embedding_size)); if (embedding_size) { - memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); - out += embedding_size * sizeof(float); + data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float)); } } @@ -3800,8 +3814,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { const size_t kv_size = kv_self.buf.size; const int kv_ntok = llama_get_kv_cache_token_count(ctx); - memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size); - memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok); + data_ctx->write(&kv_size, sizeof(kv_size)); + data_ctx->write(&kv_ntok, sizeof(kv_ntok)); if (kv_size) { const size_t elt_size = ggml_element_size(kv_self.k); @@ -3810,12 +3824,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { ggml_cgraph gf{}; ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); - kout3d->data = out; - out += ggml_nbytes(kout3d); + std::vector kout3d_data(ggml_nbytes(kout3d), 0); + kout3d->data = kout3d_data.data(); ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); - vout3d->data = out; - out += ggml_nbytes(vout3d); + std::vector vout3d_data(ggml_nbytes(vout3d), 0); + vout3d->data = vout3d_data.data(); ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, n_embd, kv_ntok, n_layer, @@ -3830,15 +3844,20 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); ggml_free(cpy_ctx); + + // our data is now in the kout3d_data and vout3d_data buffers + // write them to file + data_ctx->write(kout3d_data.data(), kout3d_data.size()); + data_ctx->write(vout3d_data.data(), vout3d_data.size()); } } +} - const size_t written = out - dst; - const size_t max_size = llama_get_state_size(ctx); +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { + llama_data_buffer_context data_ctx(dst); + llama_copy_state_data_internal(ctx, &data_ctx); - LLAMA_ASSERT(written <= max_size); - - return written; + return data_ctx.get_size_written(); } // Sets the state reading from the specified source address @@ -4023,15 +4042,9 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi file.write_u32((uint32_t) n_token_count); file.write_raw(tokens, sizeof(llama_token) * n_token_count); - // save the context state - { - const size_t n_state_size_max = llama_get_state_size(ctx); - - std::vector state_data(n_state_size_max); - const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data()); - - file.write_raw(state_data.data(), n_state_size_cur); - } + // save the context state using stream saving + llama_data_file_context data_ctx(&file); + llama_copy_state_data_internal(ctx, &data_ctx); return true; }