kv_cache : minor

This commit is contained in:
Georgi Gerganov 2025-01-14 11:56:53 +02:00
parent 422ecaf52c
commit d8b9013108
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 47 additions and 27 deletions

View File

@ -73,17 +73,22 @@ bool llama_kv_cache::init(
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);
const char * dev_name = "CPU";
ggml_backend_buffer_type_t buft;
if (offload) {
auto * dev = model.dev_layer(i);
buft = ggml_backend_dev_buffer_type(dev);
dev_name = ggml_backend_dev_name(dev);
} else {
buft = ggml_backend_cpu_buffer_type();
}
ggml_context * ctx = ctx_for_buft(buft);
LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
ggml_context * ctx = ctx_for_buft(buft);
if (!ctx) {
LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
return false;
@ -134,14 +139,13 @@ size_t llama_kv_cache::total_size() const {
return size;
}
// TODO: better data structures to reduce the cost of this operation
llama_pos llama_kv_cache::max_pos() const {
llama_pos max_pos = -1;
llama_pos llama_kv_cache::pos_max() const {
llama_pos pos_max = -1;
for (const auto & cell : cells) {
max_pos = std::max(max_pos, cell.pos);
pos_max = std::max(pos_max, cell.pos);
}
return max_pos;
return pos_max;
}
void llama_kv_cache::clear() {
@ -672,6 +676,26 @@ uint32_t llama_kv_cache::cell_max() const {
return 0;
}
size_t llama_kv_cache::size_k_bytes() const {
size_t size_k_bytes = 0;
for (const auto & k : k_l) {
size_k_bytes += ggml_nbytes(k);
}
return size_k_bytes;
}
size_t llama_kv_cache::size_v_bytes() const {
size_t size_v_bytes = 0;
for (const auto & v : v_l) {
size_v_bytes += ggml_nbytes(v);
}
return size_v_bytes;
}
void llama_kv_cache_clear(llama_kv_cache * kv) {
kv->clear();
}

View File

@ -61,17 +61,11 @@ struct llama_kv_cache {
// computed before each graph build
uint32_t n = 0;
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<llama_kv_cell> cells;
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// TODO: become constructor
bool init(
const llama_model & model,
@ -86,7 +80,7 @@ struct llama_kv_cache {
size_t total_size() const;
// TODO: better data structures to reduce the cost of this operation
llama_pos max_pos() const;
llama_pos pos_max() const;
void clear();
@ -112,6 +106,16 @@ struct llama_kv_cache {
// find how many cells are currently in use
uint32_t cell_max() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
private:
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
};
//

View File

@ -1973,7 +1973,7 @@ struct llm_build_context {
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@ -8456,7 +8456,7 @@ static int llama_decode_impl(
}
// temporary allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens_all = batch.n_tokens;
@ -8792,7 +8792,7 @@ static int llama_encode_impl(
}
// temporary allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens = batch.n_tokens;
@ -9689,16 +9689,8 @@ struct llama_context * llama_init_from_model(
}
{
size_t memory_size_k = 0;
size_t memory_size_v = 0;
for (auto & k : ctx->kv_self.k_l) {
memory_size_k += ggml_nbytes(k);
}
for (auto & v : ctx->kv_self.v_l) {
memory_size_v += ggml_nbytes(v);
}
const size_t memory_size_k = ctx->kv_self.size_k_bytes();
const size_t memory_size_v = ctx->kv_self.size_v_bytes();
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),