Stream save llama context data to file instead of allocating entire buffer upfront (#2488)
* added stream saving of context data to file to avoid allocating unnecessary amounts of memory
* generalised copying state data to either a file or a buffer
* added comments explaining how copy_state_data works
* fixed trailing whitespace
* fixed save-load-state example
* updated save-load-state to use the public function in llama.cpp
* restored the llama_copy_state_data API (no breaking change) and moved the new state-copying logic to an internal function
* fixed function declaration order
* restored save-load-state example
* fixed whitespace
* removed unused llama-util.h include
* applied code review suggestions

Co-authored-by: slaren <slarengh@gmail.com>
commit 415e99fec2 (parent ff966e7ca6)
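In short: instead of serializing the whole context state into one worst-case-sized buffer and then writing that buffer out, the state is now streamed piece by piece through a small writer abstraction. A minimal sketch of the two paths (not part of the diff; it assumes the code runs inside llama.cpp, where llama_file, llama_data_file_context and llama_copy_state_data_internal are visible, that ctx is an existing llama_context *, and that the file name is made up for illustration):

    // Old approach: allocate the upper-bound state size up front, then write it out.
    //   std::vector<uint8_t> state_data(llama_get_state_size(ctx));   // can be very large
    //   const size_t n = llama_copy_state_data(ctx, state_data.data());
    //   file.write_raw(state_data.data(), n);

    // New approach: stream each piece of state directly into the file.
    llama_file file("session.bin", "wb");                  // hypothetical path
    llama_data_file_context data_ctx(&file);               // writer that appends to the file
    llama_copy_state_data_internal(ctx, &data_ctx);        // emits the state through data_ctx
    const size_t n_written = data_ctx.get_size_written();  // bytes actually streamed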
llama-util.h  (40 changed lines)

@@ -149,6 +149,46 @@ struct llama_file {
     }
 };
 
+// llama_data_context
+struct llama_data_context {
+    virtual void write(const void * src, size_t size) = 0;
+    virtual size_t get_size_written() = 0;
+    virtual ~llama_data_context() = default;
+};
+
+struct llama_data_buffer_context : llama_data_context {
+    uint8_t* ptr;
+    size_t size_written = 0;
+
+    llama_data_buffer_context(uint8_t * p) : ptr(p) {}
+
+    void write(const void * src, size_t size) override {
+        memcpy(ptr, src, size);
+        ptr += size;
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
+struct llama_data_file_context : llama_data_context {
+    llama_file* file;
+    size_t size_written = 0;
+
+    llama_data_file_context(llama_file * f) : file(f) {}
+
+    void write(const void * src, size_t size) override {
+        file->write_raw(src, size);
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
 #if defined(_WIN32)
 static std::string llama_format_win_err(DWORD err) {
     LPSTR buf;
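Editorial note: the value of the llama_data_context interface above is that the serialization code in llama.cpp no longer cares where the bytes end up. As an illustration only (not in this commit), a third sink that merely counts bytes, e.g. to measure the exact serialized size without allocating or copying anything, would satisfy the same interface:

    // Hypothetical example: a measuring sink implementing llama_data_context.
    struct llama_data_size_context : llama_data_context {
        size_t size_written = 0;

        void write(const void * /*src*/, size_t size) override {
            size_written += size;   // count the bytes, do not store them
        }

        size_t get_size_written() override {
            return size_written;
        }
    };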
llama.cpp  (79 changed lines)

@@ -3743,10 +3743,20 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     return s_total;
 }
 
-// Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    uint8_t * out = dst;
-
+/** copy state data into either a buffer or file depending on the passed in context
+ *
+ * file context:
+ * llama_file file("/path", "wb");
+ * llama_data_file_context data_ctx(&file);
+ * llama_copy_state_data_internal(ctx, &data_ctx);
+ *
+ * buffer context:
+ * std::vector<uint8_t> buf(max_size, 0);
+ * llama_data_buffer_context data_ctx(buf.data());
+ * llama_copy_state_data_internal(ctx, &data_ctx);
+ *
+*/
+void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
     // copy rng
     {
         std::stringstream rng_ss;
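For callers of the public API nothing changes: llama_copy_state_data still takes a raw destination buffer and returns the number of bytes written; internally it now wraps that buffer in a llama_data_buffer_context and streams into it (see the wrapper added at the end of the hunks below). A short caller-side sketch, assuming ctx is an existing llama_context *:

    std::vector<uint8_t> buf(llama_get_state_size(ctx));          // upper bound on the state size
    const size_t n_used = llama_copy_state_data(ctx, buf.data()); // actual bytes written
    buf.resize(n_used);                                           // the real state can be smaller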
@@ -3758,8 +3768,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
         memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
 
-        memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size);
-        memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
+        data_ctx->write(&rng_size, sizeof(rng_size));
+        data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE);
     }
 
     // copy logits
@@ -3767,25 +3777,29 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         const size_t logits_cap = ctx->logits.capacity();
         const size_t logits_size = ctx->logits.size();
 
-        memcpy(out, &logits_cap, sizeof(logits_cap)); out += sizeof(logits_cap);
-        memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
+        data_ctx->write(&logits_cap, sizeof(logits_cap));
+        data_ctx->write(&logits_size, sizeof(logits_size));
 
         if (logits_size) {
-            memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+            data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
         }
 
-        out += logits_cap * sizeof(float);
+        // If there is a gap between the size and the capacity, write padding
+        size_t padding_size = (logits_cap - logits_size) * sizeof(float);
+        if (padding_size > 0) {
+            std::vector<uint8_t> padding(padding_size, 0); // Create a buffer filled with zeros
+            data_ctx->write(padding.data(), padding_size);
+        }
     }
 
     // copy embeddings
     {
         const size_t embedding_size = ctx->embedding.size();
 
-        memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
+        data_ctx->write(&embedding_size, sizeof(embedding_size));
 
         if (embedding_size) {
-            memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
-            out += embedding_size * sizeof(float);
+            data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float));
         }
     }
 
@@ -3800,8 +3814,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         const size_t kv_size = kv_self.buf.size;
         const int kv_ntok = llama_get_kv_cache_token_count(ctx);
 
-        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
-        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
+        data_ctx->write(&kv_size, sizeof(kv_size));
+        data_ctx->write(&kv_ntok, sizeof(kv_ntok));
 
         if (kv_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
@@ -3810,12 +3824,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
             ggml_cgraph gf{};
 
             ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
-            kout3d->data = out;
-            out += ggml_nbytes(kout3d);
+            std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
+            kout3d->data = kout3d_data.data();
 
             ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
-            vout3d->data = out;
-            out += ggml_nbytes(vout3d);
+            std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
+            vout3d->data = vout3d_data.data();
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
                 n_embd, kv_ntok, n_layer,
@@ -3830,15 +3844,20 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
             ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
 
             ggml_free(cpy_ctx);
+
+            // our data is now in the kout3d_data and vout3d_data buffers
+            // write them to file
+            data_ctx->write(kout3d_data.data(), kout3d_data.size());
+            data_ctx->write(vout3d_data.data(), vout3d_data.size());
         }
     }
+}
 
-    const size_t written = out - dst;
-    const size_t max_size = llama_get_state_size(ctx);
-
-    LLAMA_ASSERT(written <= max_size);
-
-    return written;
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
+    llama_data_buffer_context data_ctx(dst);
+    llama_copy_state_data_internal(ctx, &data_ctx);
+
+    return data_ctx.get_size_written();
 }
 
 // Sets the state reading from the specified source address
@@ -4023,15 +4042,9 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     file.write_u32((uint32_t) n_token_count);
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
-    // save the context state
-    {
-        const size_t n_state_size_max = llama_get_state_size(ctx);
-
-        std::vector<uint8_t> state_data(n_state_size_max);
-        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());
-
-        file.write_raw(state_data.data(), n_state_size_cur);
-    }
+    // save the context state using stream saving
+    llama_data_file_context data_ctx(&file);
+    llama_copy_state_data_internal(ctx, &data_ctx);
 
     return true;
 }
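Net effect in llama_save_session_file: the context state is streamed into the session file as it is produced, so the temporary state_data vector (sized by llama_get_state_size, an upper bound that includes the full KV cache) is no longer allocated. Caller-side usage is unchanged; a minimal sketch, with ctx, the path, and the token bookkeeping assumed for illustration:

    std::vector<llama_token> session_tokens;   // filled by the caller as tokens are evaluated
    // ... evaluate the prompt, appending each evaluated token to session_tokens ...

    if (!llama_save_session_file(ctx, "session.bin", session_tokens.data(), session_tokens.size())) {
        fprintf(stderr, "failed to save session to session.bin\n");
    }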