diff --git a/llama.cpp b/llama.cpp index 8c1d65778..bc0ef1281 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2072,35 +2072,191 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor } } -// Returns the KV cache that will contain the context for the -// ongoing prediction with the model. -const uint8_t * llama_get_kv_cache(struct llama_context * ctx) { - return ctx->model.kv_self.buf.addr; -} - -// Returns the size of the KV cache -size_t llama_get_kv_cache_size(struct llama_context * ctx) { - return ctx->model.kv_self.buf.size; -} - int llama_get_kv_cache_token_count(struct llama_context * ctx) { return ctx->model.kv_self.n; } -// Sets the KV cache containing the current context for the model -void llama_set_kv_cache( - struct llama_context * ctx, - const uint8_t * kv_cache, - size_t n_size, - int n_token_count) { - // Make sure we have the same kv cache setup - LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size); - void * k_data = ctx->model.kv_self.k->data; // remember data pointers - void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy - memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size); - ctx->model.kv_self.k->data = k_data; // restore correct data pointers - ctx->model.kv_self.v->data = v_data; - ctx->model.kv_self.n = n_token_count; +#define LLAMA_MAX_RNG_STATE 64*1024 + +// Returns the size of the state +size_t llama_get_state_size(struct llama_context * ctx) { + // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. + // for reference, std::mt19937(1337) serializes to 6701 bytes. + const size_t s_rng_size = sizeof(size_t); + const size_t s_rng = LLAMA_MAX_RNG_STATE; + const size_t s_logits_capacity = sizeof(size_t); + const size_t s_logits_size = sizeof(size_t); + const size_t s_logits = ctx->logits.capacity() * sizeof(float); + const size_t s_embedding_size = sizeof(size_t); + const size_t s_embedding = ctx->embedding.size() * sizeof(float); + const size_t s_kv_size = sizeof(size_t); + const size_t s_kv_ntok = sizeof(int); + const size_t s_kv = ctx->model.kv_self.buf.size; + + const size_t s_total = ( + + s_rng_size + + s_rng + + s_logits_capacity + + s_logits_size + + s_logits + + s_embedding_size + + s_embedding + + s_kv_size + + s_kv_ntok + + s_kv + ); + + return s_total; +} + +// Copies the state to the specified destination address +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { + uint8_t * out = dest; + + // copy rng + { + std::stringstream rng_ss; + rng_ss << ctx->rng; + + const size_t rng_size = rng_ss.str().size(); + char rng_buf[LLAMA_MAX_RNG_STATE]; + + memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); + memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); + + memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size); + memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE; + } + + // copy logits + { + const size_t logits_cap = ctx->logits.capacity(); + const size_t logits_size = ctx->logits.size(); + + memcpy(out, &logits_cap, sizeof(logits_cap)); out += sizeof(logits_cap); + memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size); + + if (logits_size) { + memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); + } + + out += logits_cap * sizeof(float); + } + + // copy embeddings + { + const size_t embedding_size = ctx->embedding.size(); + + memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size); + + if (embedding_size) { + memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); + out += embedding_size * sizeof(float); + } + } + + // copy kv cache + { + const size_t kv_size = ctx->model.kv_self.buf.size; + const int kv_ntok = llama_get_kv_cache_token_count(ctx); + + memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size); + memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok); + + if (kv_size) { + memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size; + } + } + + const size_t written = out - dest; + const size_t expected = llama_get_state_size(ctx); + + LLAMA_ASSERT(written == expected); + + return written; +} + +// Sets the state reading from the specified source address +size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { + const uint8_t * in = src; + + // set rng + { + size_t rng_size; + char rng_buf[LLAMA_MAX_RNG_STATE]; + + memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size); + memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE; + + std::stringstream rng_ss; + rng_ss.str(std::string(&rng_buf[0], rng_size)); + rng_ss >> ctx->rng; + + LLAMA_ASSERT(rng_ss.fail() == false); + } + + // set logits + { + size_t logits_cap; + size_t logits_size; + + memcpy(&logits_cap, in, sizeof(logits_cap)); in += sizeof(logits_cap); + memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size); + + LLAMA_ASSERT(ctx->logits.capacity() == logits_cap); + + if (logits_size) { + ctx->logits.resize(logits_size); + memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); + } + + in += logits_cap * sizeof(float); + } + + // set embeddings + { + size_t embedding_size; + + memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size); + + LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); + + if (embedding_size) { + memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); + in += embedding_size * sizeof(float); + } + } + + // set kv cache + { + size_t kv_size; + int kv_ntok; + + memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size); + memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok); + + if (kv_size) { + LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size); + + void * k_data = ctx->model.kv_self.k->data; // remember data pointers + void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy + + memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size; + + ctx->model.kv_self.k->data = k_data; // restore correct data pointers + ctx->model.kv_self.v->data = v_data; + + } + + ctx->model.kv_self.n = kv_ntok; + } + + const size_t nread = in - src; + const size_t expected = llama_get_state_size(ctx); + + LLAMA_ASSERT(nread == expected); + + return nread; } int llama_eval( @@ -2256,120 +2412,3 @@ std::vector>& llama_internal_get_te return ctx->model.tensors_by_name; } -// Returns the size of the state -size_t llama_get_state_size(struct llama_context * ctx) { - // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. - // for reference, std::mt19937(1337) serializes to 6701 bytes. - const size_t s_rng_size = sizeof(size_t); - const size_t s_rng = 64*1024; - const size_t s_logits_capacity = sizeof(size_t); - const size_t s_logits_size = sizeof(size_t); - const size_t s_logits = ctx->logits.capacity() * sizeof(float); - const size_t s_embedding_size = sizeof(size_t); - const size_t s_embedding = ctx->embedding.size() * sizeof(float); - const size_t s_kv_size = sizeof(size_t); - const size_t s_kv_ntok = sizeof(int); - const size_t s_kv = llama_get_kv_cache_size(ctx); - const size_t s_total = ( - + s_rng_size - + s_rng - + s_logits_capacity - + s_logits_size - + s_logits - + s_embedding_size - + s_embedding - + s_kv_size - + s_kv_ntok - + s_kv - ); - return s_total; -} - -// Copies the state to the specified destination address -size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { - std::stringstream rng_ss; - rng_ss << ctx->rng; - const size_t rng_size = rng_ss.str().size(); - char rng_buf[64*1024]; - memset(&rng_buf[0], 0, 64*1024); - memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); - const size_t logits_capacity = ctx->logits.capacity(); - const size_t logits_size = ctx->logits.size(); - const size_t embedding_size = ctx->embedding.size(); - const size_t kv_size = llama_get_kv_cache_size(ctx); - const int kv_ntok = llama_get_kv_cache_token_count(ctx); - - uint8_t * out = dest; - memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t); - memcpy(out, &rng_buf[0], 64*1024); out += 64*1024; - memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t); - memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t); - if (logits_size) { - memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); - } - out += logits_capacity * sizeof(float); - memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t); - if (embedding_size) { - memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float); - } - memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t); - memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int); - if (kv_size) { - memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size; - } - const size_t written = out - dest; - const size_t expected = llama_get_state_size(ctx); - LLAMA_ASSERT(written == expected); - return written; -} - -// Sets the state reading from the specified source address -size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { - size_t rng_size; - char rng_buf[64*1024]; - std::stringstream rng_ss; - - const uint8_t * in = src; - memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t); - memcpy(&rng_buf[0], in, 64*1024); in += 64*1024; - rng_ss.str(std::string(&rng_buf[0], rng_size)); - rng_ss >> ctx->rng; - LLAMA_ASSERT(rng_ss.fail() == false); - - size_t logits_capacity; - size_t logits_size; - size_t embedding_size; - size_t kv_size; - int kv_ntok; - - memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t); - memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t); - LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity); - if (logits_size) { - ctx->logits.resize(logits_size); - memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); - } - in += logits_capacity * sizeof(float); - memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t); - LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); - if (embedding_size) { - memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); - in += embedding_size * sizeof(float); - } - memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t); - memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int); - if (kv_size) { - LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size); - void * k_data = ctx->model.kv_self.k->data; // remember data pointers - void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy - memcpy(ctx->model.kv_self.buf.addr, in, kv_size); - ctx->model.kv_self.k->data = k_data; // restore correct data pointers - ctx->model.kv_self.v->data = v_data; - in += kv_size; - } - ctx->model.kv_self.n = kv_ntok; - const size_t nread = in - src; - const size_t expected = llama_get_state_size(ctx); - LLAMA_ASSERT(nread == expected); - return nread; -} diff --git a/llama.h b/llama.h index f68a0cb40..e9e3abea5 100644 --- a/llama.h +++ b/llama.h @@ -112,23 +112,9 @@ extern "C" { const char * path_base_model, int n_threads); - // Returns the KV cache that will contain the context for the - // ongoing prediction with the model. - LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx); - - // Returns the size of the KV cache - LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx); - // Returns the number of tokens in the KV cache LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx); - // Sets the KV cache containing the current context for the model - LLAMA_API void llama_set_kv_cache( - struct llama_context * ctx, - const uint8_t * kv_cache, - size_t n_size, - int n_token_count); - // Returns the size in bytes of the state (rng, logits, embedding and kv_cache) LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);