mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 06:10:29 +01:00
llama : refactor get / set state + remove redundant kv cache API (#1143)
This commit is contained in:
parent
1d78fecdab
commit
c4fe84fb0d
315
llama.cpp
315
llama.cpp
@ -2072,35 +2072,191 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the KV cache that will contain the context for the
|
|
||||||
// ongoing prediction with the model.
|
|
||||||
const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
|
|
||||||
return ctx->model.kv_self.buf.addr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the size of the KV cache
|
|
||||||
size_t llama_get_kv_cache_size(struct llama_context * ctx) {
|
|
||||||
return ctx->model.kv_self.buf.size;
|
|
||||||
}
|
|
||||||
|
|
||||||
int llama_get_kv_cache_token_count(struct llama_context * ctx) {
|
int llama_get_kv_cache_token_count(struct llama_context * ctx) {
|
||||||
return ctx->model.kv_self.n;
|
return ctx->model.kv_self.n;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets the KV cache containing the current context for the model
|
#define LLAMA_MAX_RNG_STATE 64*1024
|
||||||
void llama_set_kv_cache(
|
|
||||||
struct llama_context * ctx,
|
// Returns the size of the state
|
||||||
const uint8_t * kv_cache,
|
size_t llama_get_state_size(struct llama_context * ctx) {
|
||||||
size_t n_size,
|
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
|
||||||
int n_token_count) {
|
// for reference, std::mt19937(1337) serializes to 6701 bytes.
|
||||||
// Make sure we have the same kv cache setup
|
const size_t s_rng_size = sizeof(size_t);
|
||||||
LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size);
|
const size_t s_rng = LLAMA_MAX_RNG_STATE;
|
||||||
|
const size_t s_logits_capacity = sizeof(size_t);
|
||||||
|
const size_t s_logits_size = sizeof(size_t);
|
||||||
|
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
|
||||||
|
const size_t s_embedding_size = sizeof(size_t);
|
||||||
|
const size_t s_embedding = ctx->embedding.size() * sizeof(float);
|
||||||
|
const size_t s_kv_size = sizeof(size_t);
|
||||||
|
const size_t s_kv_ntok = sizeof(int);
|
||||||
|
const size_t s_kv = ctx->model.kv_self.buf.size;
|
||||||
|
|
||||||
|
const size_t s_total = (
|
||||||
|
+ s_rng_size
|
||||||
|
+ s_rng
|
||||||
|
+ s_logits_capacity
|
||||||
|
+ s_logits_size
|
||||||
|
+ s_logits
|
||||||
|
+ s_embedding_size
|
||||||
|
+ s_embedding
|
||||||
|
+ s_kv_size
|
||||||
|
+ s_kv_ntok
|
||||||
|
+ s_kv
|
||||||
|
);
|
||||||
|
|
||||||
|
return s_total;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copies the state to the specified destination address
|
||||||
|
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
|
||||||
|
uint8_t * out = dest;
|
||||||
|
|
||||||
|
// copy rng
|
||||||
|
{
|
||||||
|
std::stringstream rng_ss;
|
||||||
|
rng_ss << ctx->rng;
|
||||||
|
|
||||||
|
const size_t rng_size = rng_ss.str().size();
|
||||||
|
char rng_buf[LLAMA_MAX_RNG_STATE];
|
||||||
|
|
||||||
|
memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
|
||||||
|
memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
|
||||||
|
|
||||||
|
memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size);
|
||||||
|
memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy logits
|
||||||
|
{
|
||||||
|
const size_t logits_cap = ctx->logits.capacity();
|
||||||
|
const size_t logits_size = ctx->logits.size();
|
||||||
|
|
||||||
|
memcpy(out, &logits_cap, sizeof(logits_cap)); out += sizeof(logits_cap);
|
||||||
|
memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
|
||||||
|
|
||||||
|
if (logits_size) {
|
||||||
|
memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
|
out += logits_cap * sizeof(float);
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy embeddings
|
||||||
|
{
|
||||||
|
const size_t embedding_size = ctx->embedding.size();
|
||||||
|
|
||||||
|
memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
|
||||||
|
|
||||||
|
if (embedding_size) {
|
||||||
|
memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
|
||||||
|
out += embedding_size * sizeof(float);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy kv cache
|
||||||
|
{
|
||||||
|
const size_t kv_size = ctx->model.kv_self.buf.size;
|
||||||
|
const int kv_ntok = llama_get_kv_cache_token_count(ctx);
|
||||||
|
|
||||||
|
memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
|
||||||
|
memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
|
||||||
|
|
||||||
|
if (kv_size) {
|
||||||
|
memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t written = out - dest;
|
||||||
|
const size_t expected = llama_get_state_size(ctx);
|
||||||
|
|
||||||
|
LLAMA_ASSERT(written == expected);
|
||||||
|
|
||||||
|
return written;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets the state reading from the specified source address
|
||||||
|
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
|
const uint8_t * in = src;
|
||||||
|
|
||||||
|
// set rng
|
||||||
|
{
|
||||||
|
size_t rng_size;
|
||||||
|
char rng_buf[LLAMA_MAX_RNG_STATE];
|
||||||
|
|
||||||
|
memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size);
|
||||||
|
memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE;
|
||||||
|
|
||||||
|
std::stringstream rng_ss;
|
||||||
|
rng_ss.str(std::string(&rng_buf[0], rng_size));
|
||||||
|
rng_ss >> ctx->rng;
|
||||||
|
|
||||||
|
LLAMA_ASSERT(rng_ss.fail() == false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// set logits
|
||||||
|
{
|
||||||
|
size_t logits_cap;
|
||||||
|
size_t logits_size;
|
||||||
|
|
||||||
|
memcpy(&logits_cap, in, sizeof(logits_cap)); in += sizeof(logits_cap);
|
||||||
|
memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);
|
||||||
|
|
||||||
|
LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);
|
||||||
|
|
||||||
|
if (logits_size) {
|
||||||
|
ctx->logits.resize(logits_size);
|
||||||
|
memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
|
in += logits_cap * sizeof(float);
|
||||||
|
}
|
||||||
|
|
||||||
|
// set embeddings
|
||||||
|
{
|
||||||
|
size_t embedding_size;
|
||||||
|
|
||||||
|
memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);
|
||||||
|
|
||||||
|
LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
|
||||||
|
|
||||||
|
if (embedding_size) {
|
||||||
|
memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
|
||||||
|
in += embedding_size * sizeof(float);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set kv cache
|
||||||
|
{
|
||||||
|
size_t kv_size;
|
||||||
|
int kv_ntok;
|
||||||
|
|
||||||
|
memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
|
||||||
|
memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
|
||||||
|
|
||||||
|
if (kv_size) {
|
||||||
|
LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
|
||||||
|
|
||||||
void * k_data = ctx->model.kv_self.k->data; // remember data pointers
|
void * k_data = ctx->model.kv_self.k->data; // remember data pointers
|
||||||
void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
|
void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
|
||||||
memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size);
|
|
||||||
|
memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
|
||||||
|
|
||||||
ctx->model.kv_self.k->data = k_data; // restore correct data pointers
|
ctx->model.kv_self.k->data = k_data; // restore correct data pointers
|
||||||
ctx->model.kv_self.v->data = v_data;
|
ctx->model.kv_self.v->data = v_data;
|
||||||
ctx->model.kv_self.n = n_token_count;
|
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->model.kv_self.n = kv_ntok;
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t nread = in - src;
|
||||||
|
const size_t expected = llama_get_state_size(ctx);
|
||||||
|
|
||||||
|
LLAMA_ASSERT(nread == expected);
|
||||||
|
|
||||||
|
return nread;
|
||||||
}
|
}
|
||||||
|
|
||||||
int llama_eval(
|
int llama_eval(
|
||||||
@ -2256,120 +2412,3 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_te
|
|||||||
return ctx->model.tensors_by_name;
|
return ctx->model.tensors_by_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the size of the state
|
|
||||||
size_t llama_get_state_size(struct llama_context * ctx) {
|
|
||||||
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
|
|
||||||
// for reference, std::mt19937(1337) serializes to 6701 bytes.
|
|
||||||
const size_t s_rng_size = sizeof(size_t);
|
|
||||||
const size_t s_rng = 64*1024;
|
|
||||||
const size_t s_logits_capacity = sizeof(size_t);
|
|
||||||
const size_t s_logits_size = sizeof(size_t);
|
|
||||||
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
|
|
||||||
const size_t s_embedding_size = sizeof(size_t);
|
|
||||||
const size_t s_embedding = ctx->embedding.size() * sizeof(float);
|
|
||||||
const size_t s_kv_size = sizeof(size_t);
|
|
||||||
const size_t s_kv_ntok = sizeof(int);
|
|
||||||
const size_t s_kv = llama_get_kv_cache_size(ctx);
|
|
||||||
const size_t s_total = (
|
|
||||||
+ s_rng_size
|
|
||||||
+ s_rng
|
|
||||||
+ s_logits_capacity
|
|
||||||
+ s_logits_size
|
|
||||||
+ s_logits
|
|
||||||
+ s_embedding_size
|
|
||||||
+ s_embedding
|
|
||||||
+ s_kv_size
|
|
||||||
+ s_kv_ntok
|
|
||||||
+ s_kv
|
|
||||||
);
|
|
||||||
return s_total;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copies the state to the specified destination address
|
|
||||||
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
|
|
||||||
std::stringstream rng_ss;
|
|
||||||
rng_ss << ctx->rng;
|
|
||||||
const size_t rng_size = rng_ss.str().size();
|
|
||||||
char rng_buf[64*1024];
|
|
||||||
memset(&rng_buf[0], 0, 64*1024);
|
|
||||||
memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
|
|
||||||
const size_t logits_capacity = ctx->logits.capacity();
|
|
||||||
const size_t logits_size = ctx->logits.size();
|
|
||||||
const size_t embedding_size = ctx->embedding.size();
|
|
||||||
const size_t kv_size = llama_get_kv_cache_size(ctx);
|
|
||||||
const int kv_ntok = llama_get_kv_cache_token_count(ctx);
|
|
||||||
|
|
||||||
uint8_t * out = dest;
|
|
||||||
memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t);
|
|
||||||
memcpy(out, &rng_buf[0], 64*1024); out += 64*1024;
|
|
||||||
memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
|
|
||||||
memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t);
|
|
||||||
if (logits_size) {
|
|
||||||
memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
|
|
||||||
}
|
|
||||||
out += logits_capacity * sizeof(float);
|
|
||||||
memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
|
|
||||||
if (embedding_size) {
|
|
||||||
memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
|
|
||||||
}
|
|
||||||
memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
|
|
||||||
memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int);
|
|
||||||
if (kv_size) {
|
|
||||||
memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
|
|
||||||
}
|
|
||||||
const size_t written = out - dest;
|
|
||||||
const size_t expected = llama_get_state_size(ctx);
|
|
||||||
LLAMA_ASSERT(written == expected);
|
|
||||||
return written;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets the state reading from the specified source address
|
|
||||||
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
|
||||||
size_t rng_size;
|
|
||||||
char rng_buf[64*1024];
|
|
||||||
std::stringstream rng_ss;
|
|
||||||
|
|
||||||
const uint8_t * in = src;
|
|
||||||
memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t);
|
|
||||||
memcpy(&rng_buf[0], in, 64*1024); in += 64*1024;
|
|
||||||
rng_ss.str(std::string(&rng_buf[0], rng_size));
|
|
||||||
rng_ss >> ctx->rng;
|
|
||||||
LLAMA_ASSERT(rng_ss.fail() == false);
|
|
||||||
|
|
||||||
size_t logits_capacity;
|
|
||||||
size_t logits_size;
|
|
||||||
size_t embedding_size;
|
|
||||||
size_t kv_size;
|
|
||||||
int kv_ntok;
|
|
||||||
|
|
||||||
memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
|
|
||||||
memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t);
|
|
||||||
LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
|
|
||||||
if (logits_size) {
|
|
||||||
ctx->logits.resize(logits_size);
|
|
||||||
memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
|
|
||||||
}
|
|
||||||
in += logits_capacity * sizeof(float);
|
|
||||||
memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
|
|
||||||
LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
|
|
||||||
if (embedding_size) {
|
|
||||||
memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
|
|
||||||
in += embedding_size * sizeof(float);
|
|
||||||
}
|
|
||||||
memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
|
|
||||||
memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int);
|
|
||||||
if (kv_size) {
|
|
||||||
LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
|
|
||||||
void * k_data = ctx->model.kv_self.k->data; // remember data pointers
|
|
||||||
void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
|
|
||||||
memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
|
|
||||||
ctx->model.kv_self.k->data = k_data; // restore correct data pointers
|
|
||||||
ctx->model.kv_self.v->data = v_data;
|
|
||||||
in += kv_size;
|
|
||||||
}
|
|
||||||
ctx->model.kv_self.n = kv_ntok;
|
|
||||||
const size_t nread = in - src;
|
|
||||||
const size_t expected = llama_get_state_size(ctx);
|
|
||||||
LLAMA_ASSERT(nread == expected);
|
|
||||||
return nread;
|
|
||||||
}
|
|
||||||
|
14
llama.h
14
llama.h
@ -112,23 +112,9 @@ extern "C" {
|
|||||||
const char * path_base_model,
|
const char * path_base_model,
|
||||||
int n_threads);
|
int n_threads);
|
||||||
|
|
||||||
// Returns the KV cache that will contain the context for the
|
|
||||||
// ongoing prediction with the model.
|
|
||||||
LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
|
|
||||||
|
|
||||||
// Returns the size of the KV cache
|
|
||||||
LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);
|
|
||||||
|
|
||||||
// Returns the number of tokens in the KV cache
|
// Returns the number of tokens in the KV cache
|
||||||
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
|
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
|
||||||
|
|
||||||
// Sets the KV cache containing the current context for the model
|
|
||||||
LLAMA_API void llama_set_kv_cache(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
const uint8_t * kv_cache,
|
|
||||||
size_t n_size,
|
|
||||||
int n_token_count);
|
|
||||||
|
|
||||||
// Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
|
// Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
|
||||||
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
|
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user