llama : update session save/load

Georgi Gerganov 2023-12-03 21:10:16 +02:00
parent e262947d43
commit 66aaac9867
GPG Key ID: 449E073F9DC10735
2 changed files with 50 additions and 43 deletions

llama.cpp

@@ -1563,6 +1563,8 @@ static bool llama_kv_cache_init(
     const int i_gpu_start = n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start);
+    GGML_UNUSED(offload);
     for (int i = 0; i < (int) n_layer; i++) {
         ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, wtype, n_embd*n_ctx);
         ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, wtype, n_embd*n_ctx);
@@ -5406,7 +5408,7 @@ static struct ggml_cgraph * llama_build_graph(
     //
     // TODO: will be removed with backend v2
-#define LLAMA_OFFLOAD_DEBUG
+//#define LLAMA_OFFLOAD_DEBUG
     if (!do_offload) {
         return;
@@ -9297,40 +9299,45 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         data_ctx->write(&kv_used, sizeof(kv_used));
         if (kv_buf_size) {
-#pragma message("TODO: implement KV cache saving")
-#if 0
-            const size_t elt_size = ggml_element_size(kv_self.k);
+            const size_t elt_size = ggml_element_size(kv_self.k_l[0]);
-            ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
+            ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
             ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
-            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
-            std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
-            kout3d->data = kout3d_data.data();
+            std::vector<std::vector<uint8_t>> kout2d_data(n_layer);
+            std::vector<std::vector<uint8_t>> vout2d_data(n_layer);
-            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
-            std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
-            vout3d->data = vout3d_data.data();
+            for (int il = 0; il < (int) n_layer; ++il) {
+                ggml_tensor * kout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
+                kout2d_data[il].resize(ggml_nbytes(kout2d));
+                kout2d->data = kout2d_data[il].data();
-            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_head, n_layer,
-                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+                ggml_tensor * vout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
+                vout2d_data[il].resize(ggml_nbytes(vout2d));
+                vout2d->data = vout2d_data[il].data();
-            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_head, n_embd, n_layer,
-                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+                ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
+                    n_embd, kv_head,
+                    elt_size*n_embd, 0);
+                ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
+                    kv_head, n_embd,
+                    elt_size*n_ctx, 0);
+                ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k2d, kout2d));
+                ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v2d, vout2d));
+            }
-            ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d));
-            ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d));
             ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
             ggml_free(cpy_ctx);
-            // our data is now in the kout3d_data and vout3d_data buffers
+            // our data is now in the kout2d_data and vout2d_data buffers
             // write them to file
-            data_ctx->write(kout3d_data.data(), kout3d_data.size());
-            data_ctx->write(vout3d_data.data(), vout3d_data.size());
-#endif
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                data_ctx->write(kout2d_data[il].data(), kout2d_data[il].size());
+                data_ctx->write(vout2d_data[il].data(), vout2d_data[il].size());
+            }
         }
         for (uint32_t i = 0; i < kv_size; ++i) {
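
Note on the hunk above: instead of staging the whole KV cache through a pair of 3D tensors, the save path now allocates one host buffer per layer and copies each layer's K and V slice through 2D views of kv_self.k_l[il] / kv_self.v_l[il], then writes the buffers out in layer order. From the caller's side this is only reachable through the public state API. A minimal sketch of that calling side, assuming an already-initialized ctx; the helper name save_state_to_file and the path "state.bin" are illustrative, not part of this commit:

#include "llama.h"
#include <cstdio>
#include <vector>

// Sketch only: `ctx` must be a valid llama_context; "state.bin" is a
// hypothetical output path chosen for illustration.
static bool save_state_to_file(llama_context * ctx) {
    const size_t n_state = llama_get_state_size(ctx); // upper bound, KV cache included
    std::vector<uint8_t> buf(n_state);

    // llama_copy_state_data runs the per-layer kout2d/vout2d copy shown above
    const size_t n_written = llama_copy_state_data(ctx, buf.data());

    FILE * fp = std::fopen("state.bin", "wb");
    if (!fp) {
        return false;
    }
    const bool ok = std::fwrite(buf.data(), 1, n_written, fp) == n_written;
    std::fclose(fp);
    return ok;
}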
@@ -9430,35 +9437,35 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         if (kv_buf_size) {
             GGML_ASSERT(kv_self.buf.size == kv_buf_size);
-#pragma message("TODO: implement KV cache loading")
-#if 0
-            const size_t elt_size = ggml_element_size(kv_self.k);
+            const size_t elt_size = ggml_element_size(kv_self.k_l[0]);
-            ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
+            ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
             ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
-            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
-            kin3d->data = (void *) inp;
-            inp += ggml_nbytes(kin3d);
+            for (int il = 0; il < n_layer; ++il) {
+                ggml_tensor * kin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
+                kin2d->data = (void *) inp;
+                inp += ggml_nbytes(kin2d);
-            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
-            vin3d->data = (void *) inp;
-            inp += ggml_nbytes(vin3d);
+                ggml_tensor * vin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
+                vin2d->data = (void *) inp;
+                inp += ggml_nbytes(vin2d);
-            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
-                n_embd, kv_head, n_layer,
-                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+                ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
+                    n_embd, kv_head,
+                    elt_size*n_embd, 0);
-            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
-                kv_head, n_embd, n_layer,
-                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+                ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
+                    kv_head, n_embd,
+                    elt_size*n_ctx, 0);
+                ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin2d, k2d));
+                ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin2d, v2d));
+            }
-            ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d));
-            ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d));
             ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
             ggml_free(cpy_ctx);
-#endif
         }
         ctx->kv_self.head = kv_head;
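
The load path mirrors the save path: for each layer it wraps the incoming bytes in a 2D tensor and copies them into kv_self.k_l[il] / kv_self.v_l[il] through the same one-off ggml graph. A matching sketch of the calling side, again hedged: it assumes the file was written by the save sketch above and that the model and context parameters are unchanged; the helper name load_state_from_file and the path are illustrative:

#include "llama.h"
#include <cstdio>
#include <vector>

// Sketch only: reads the hypothetical "state.bin" produced by the save sketch.
static bool load_state_from_file(llama_context * ctx) {
    FILE * fp = std::fopen("state.bin", "rb");
    if (!fp) {
        return false;
    }

    std::fseek(fp, 0, SEEK_END);
    const long n_bytes = std::ftell(fp);
    std::fseek(fp, 0, SEEK_SET);
    if (n_bytes <= 0) {
        std::fclose(fp);
        return false;
    }

    std::vector<uint8_t> buf((size_t) n_bytes);
    const bool read_ok = std::fread(buf.data(), 1, buf.size(), fp) == buf.size();
    std::fclose(fp);
    if (!read_ok) {
        return false;
    }

    // llama_set_state_data runs the per-layer kin2d/vin2d copy shown above
    llama_set_state_data(ctx, buf.data());
    return true;
}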

llama.h

@@ -42,7 +42,7 @@
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 2
+#define LLAMA_SESSION_VERSION 3
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
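
Because the serialized KV layout changed, LLAMA_SESSION_VERSION is bumped to 3, so session files written by older builds are rejected on load. A hedged sketch of how an application might handle that with llama_load_session_file; the helper name, the path "prompt.session", and the 4096-token capacity are illustrative values, not part of this commit:

#include "llama.h"
#include <cstdio>
#include <vector>

// Sketch only: "prompt.session" and the 4096-token capacity are
// illustrative, not anything defined by the library.
static bool try_load_session(llama_context * ctx, std::vector<llama_token> & tokens) {
    tokens.resize(4096);
    size_t n_loaded = 0;

    if (!llama_load_session_file(ctx, "prompt.session", tokens.data(), tokens.size(), &n_loaded)) {
        // fails, among other reasons, when the file carries an older
        // LLAMA_SESSION_VERSION such as the pre-per-layer KV layout
        std::fprintf(stderr, "session missing or incompatible, starting fresh\n");
        tokens.clear();
        return false;
    }

    tokens.resize(n_loaded);
    return true; // KV cache and previously evaluated tokens are restored
}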