mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-19 08:20:10 +01:00
llama : keep track of used KV cells + better KV cache management
This commit is contained in:
parent
8e672efe63
commit
79cb8f0040
38
llama.cpp
38
llama.cpp
@ -1280,6 +1280,7 @@ struct llama_kv_cache {
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(
|
||||
|
||||
cache.head = 0;
|
||||
cache.size = n_ctx;
|
||||
cache.used = 0;
|
||||
|
||||
cache.cells.clear();
|
||||
cache.cells.resize(n_ctx);
|
||||
@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
|
||||
}
|
||||
}
|
||||
|
||||
cache.used += n_tokens;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -1647,6 +1651,9 @@ static void llama_kv_cache_seq_rm(
|
||||
continue;
|
||||
}
|
||||
if (cache.cells[i].seq_id.empty()) {
|
||||
// keep count of the number of used cells
|
||||
if (cache.cells[i].pos >= 0) cache.used--;
|
||||
|
||||
cache.cells[i].pos = -1;
|
||||
if (new_head == cache.size) new_head = i;
|
||||
}
|
||||
@ -1654,7 +1661,7 @@ static void llama_kv_cache_seq_rm(
|
||||
}
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
if (new_head != cache.size) cache.head = new_head;
|
||||
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_seq_cp(
|
||||
@ -1680,6 +1687,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (!cache.cells[i].has_seq_id(seq_id)) {
|
||||
if (cache.cells[i].pos >= 0) cache.used--;
|
||||
cache.cells[i].pos = -1;
|
||||
cache.cells[i].seq_id.clear();
|
||||
if (new_head == cache.size) new_head = i;
|
||||
@ -1690,7 +1698,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
|
||||
}
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
if (new_head != cache.size) cache.head = new_head;
|
||||
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_seq_shift(
|
||||
@ -1711,6 +1719,7 @@ static void llama_kv_cache_seq_shift(
|
||||
cache.cells[i].delta += delta;
|
||||
|
||||
if (cache.cells[i].pos < 0) {
|
||||
if (!cache.cells[i].seq_id.empty()) cache.used--;
|
||||
cache.cells[i].pos = -1;
|
||||
cache.cells[i].seq_id.clear();
|
||||
if (new_head == cache.size) new_head = i;
|
||||
@ -5469,6 +5478,12 @@ static int llama_decode_internal(
|
||||
batch.seq_id = seq_id_arr.data();
|
||||
}
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (kv_self.head > kv_self.used + 2*n_tokens) {
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||
return 1;
|
||||
}
|
||||
@ -5479,7 +5494,7 @@ static int llama_decode_internal(
|
||||
//kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
|
||||
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
|
||||
|
||||
//printf("kv_self.n = %d\n", kv_self.n);
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
ggml_allocr_reset(lctx.alloc);
|
||||
|
||||
@ -8790,7 +8805,17 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
|
||||
}
|
||||
|
||||
int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
|
||||
return ctx->kv_self.head;
|
||||
int result = 0;
|
||||
|
||||
for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
|
||||
result += ctx->kv_self.cells[i].seq_id.size();
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
|
||||
return ctx->kv_self.used;
|
||||
}
|
||||
|
||||
void llama_kv_cache_clear(struct llama_context * ctx) {
|
||||
@ -8960,10 +8985,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
||||
const size_t kv_buf_size = kv_self.buf.size;
|
||||
const uint32_t kv_head = kv_self.head;
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
const uint32_t kv_used = kv_self.used;
|
||||
|
||||
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
|
||||
data_ctx->write(&kv_head, sizeof(kv_head));
|
||||
data_ctx->write(&kv_size, sizeof(kv_size));
|
||||
data_ctx->write(&kv_used, sizeof(kv_used));
|
||||
|
||||
if (kv_buf_size) {
|
||||
const size_t elt_size = ggml_element_size(kv_self.k);
|
||||
@ -9086,10 +9113,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
|
||||
size_t kv_buf_size;
|
||||
uint32_t kv_head;
|
||||
uint32_t kv_size;
|
||||
uint32_t kv_used;
|
||||
|
||||
memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
|
||||
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
|
||||
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
|
||||
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
|
||||
|
||||
if (kv_buf_size) {
|
||||
GGML_ASSERT(kv_self.buf.size == kv_buf_size);
|
||||
@ -9124,6 +9153,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
|
||||
|
||||
ctx->kv_self.head = kv_head;
|
||||
ctx->kv_self.size = kv_size;
|
||||
ctx->kv_self.used = kv_used;
|
||||
|
||||
ctx->kv_self.cells.resize(kv_size);
|
||||
|
||||
|
9
llama.h
9
llama.h
@ -361,9 +361,12 @@ extern "C" {
|
||||
// KV cache
|
||||
//
|
||||
|
||||
// Returns the number of tokens in the KV cache
|
||||
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
||||
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
||||
|
||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||
LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
||||
|
||||
// Clear the KV cache
|
||||
LLAMA_API void llama_kv_cache_clear(
|
||||
|
Loading…
Reference in New Issue
Block a user