mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
llama : KV cache view API + better KV cache management (#4170)
* llama : keep track of used KV cells + better KV cache management * llama : zero KV cache used upon clear ggml-ci * llama : allow exporting a view of the KV cache (#4180) * Allow exporting a view of the KV cache * Allow dumping the sequences per cell in common * Track max contiguous cells value and position as well * Fix max contiguous empty cells index calculation Make dump functions deal with lengths or sequences counts > 10 better * Fix off by one error in dump_kv_cache_view * Add doc comments for KV cache view functions Eliminate cell sequence struct; use llama_seq_id directly Minor cleanups * common : add -dkvc arg for enabling kv cache dumps --------- Co-authored-by: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>
This commit is contained in:
parent
d103d935c0
commit
6b0a7420d0
@ -12,6 +12,7 @@
|
|||||||
#include <regex>
|
#include <regex>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cinttypes>
|
#include <cinttypes>
|
||||||
@ -495,6 +496,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||||||
params.chatml = true;
|
params.chatml = true;
|
||||||
} else if (arg == "--infill") {
|
} else if (arg == "--infill") {
|
||||||
params.infill = true;
|
params.infill = true;
|
||||||
|
} else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
|
||||||
|
params.dump_kv_cache = true;
|
||||||
} else if (arg == "--multiline-input") {
|
} else if (arg == "--multiline-input") {
|
||||||
params.multiline_input = true;
|
params.multiline_input = true;
|
||||||
} else if (arg == "--simple-io") {
|
} else if (arg == "--simple-io") {
|
||||||
@ -835,6 +838,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
#endif
|
#endif
|
||||||
printf(" --verbose-prompt print prompt before generation\n");
|
printf(" --verbose-prompt print prompt before generation\n");
|
||||||
|
printf(" -dkvc, --dump-kv-cache\n");
|
||||||
|
printf(" verbose print of the KV cache\n");
|
||||||
printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
|
printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
|
||||||
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
|
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
|
||||||
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
|
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
|
||||||
@ -1386,3 +1391,77 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||||||
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
|
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
|
||||||
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
|
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// KV cache utils
|
||||||
|
//
|
||||||
|
|
||||||
|
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
|
||||||
|
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
|
||||||
|
|
||||||
|
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
|
||||||
|
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
|
||||||
|
|
||||||
|
llama_kv_cache_view_cell * c_curr = view.cells;
|
||||||
|
llama_seq_id * cs_curr = view.cells_sequences;
|
||||||
|
|
||||||
|
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
|
||||||
|
if (i % row_size == 0) {
|
||||||
|
printf("\n%5d: ", i);
|
||||||
|
}
|
||||||
|
int seq_count = 0;
|
||||||
|
for (int j = 0; j < view.n_max_seq; j++) {
|
||||||
|
if (cs_curr[j] >= 0) { seq_count++; }
|
||||||
|
}
|
||||||
|
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("\n=== Done dumping\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
|
||||||
|
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
|
||||||
|
|
||||||
|
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
|
||||||
|
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
|
||||||
|
|
||||||
|
std::unordered_map<llama_seq_id, size_t> seqs;
|
||||||
|
llama_kv_cache_view_cell * c_curr = view.cells;
|
||||||
|
llama_seq_id * cs_curr = view.cells_sequences;
|
||||||
|
|
||||||
|
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
|
||||||
|
for (int j = 0; j < view.n_max_seq; j++) {
|
||||||
|
if (cs_curr[j] < 0) { continue; }
|
||||||
|
if (seqs.find(cs_curr[j]) == seqs.end()) {
|
||||||
|
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
|
||||||
|
seqs[cs_curr[j]] = seqs.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("=== Sequence legend: ");
|
||||||
|
for (const auto & it : seqs) {
|
||||||
|
printf("%zu=%d, ", it.second, it.first);
|
||||||
|
}
|
||||||
|
printf("'+'=other sequence ids");
|
||||||
|
|
||||||
|
c_curr = view.cells;
|
||||||
|
cs_curr = view.cells_sequences;
|
||||||
|
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
|
||||||
|
if (i % row_size == 0) {
|
||||||
|
printf("\n%5d: ", i);
|
||||||
|
}
|
||||||
|
for (int j = 0; j < view.n_max_seq; j++) {
|
||||||
|
if (cs_curr[j] >= 0) {
|
||||||
|
const auto & it = seqs.find(cs_curr[j]);
|
||||||
|
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
|
||||||
|
} else {
|
||||||
|
putchar('.');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
putchar(' ');
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("\n=== Done dumping\n");
|
||||||
|
}
|
||||||
|
@ -122,6 +122,7 @@ struct gpt_params {
|
|||||||
bool numa = false; // attempt optimizations that help on some NUMA systems
|
bool numa = false; // attempt optimizations that help on some NUMA systems
|
||||||
bool verbose_prompt = false; // print prompt tokens before generation
|
bool verbose_prompt = false; // print prompt tokens before generation
|
||||||
bool infill = false; // use infill mode
|
bool infill = false; // use infill mode
|
||||||
|
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
|
||||||
|
|
||||||
// multimodal models (see examples/llava)
|
// multimodal models (see examples/llava)
|
||||||
std::string mmproj = ""; // path to multimodal projector
|
std::string mmproj = ""; // path to multimodal projector
|
||||||
@ -218,3 +219,13 @@ std::string get_sortable_timestamp();
|
|||||||
void dump_non_result_info_yaml(
|
void dump_non_result_info_yaml(
|
||||||
FILE * stream, const gpt_params & params, const llama_context * lctx,
|
FILE * stream, const gpt_params & params, const llama_context * lctx,
|
||||||
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
|
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
|
||||||
|
|
||||||
|
//
|
||||||
|
// KV cache utils
|
||||||
|
//
|
||||||
|
|
||||||
|
// Dump the KV cache view with the number of sequences per cell.
|
||||||
|
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
|
||||||
|
|
||||||
|
// Dump the KV cache view showing individual sequences in each cell (long output).
|
||||||
|
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
|
||||||
|
@ -113,6 +113,8 @@ int main(int argc, char ** argv) {
|
|||||||
// insert new requests as soon as the previous one is done
|
// insert new requests as soon as the previous one is done
|
||||||
const bool cont_batching = params.cont_batching;
|
const bool cont_batching = params.cont_batching;
|
||||||
|
|
||||||
|
const bool dump_kv_cache = params.dump_kv_cache;
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
log_set_target(log_filename_generator("parallel", "log"));
|
log_set_target(log_filename_generator("parallel", "log"));
|
||||||
LOG_TEE("Log start\n");
|
LOG_TEE("Log start\n");
|
||||||
@ -172,6 +174,8 @@ int main(int argc, char ** argv) {
|
|||||||
int32_t n_total_gen = 0;
|
int32_t n_total_gen = 0;
|
||||||
int32_t n_cache_miss = 0;
|
int32_t n_cache_miss = 0;
|
||||||
|
|
||||||
|
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);
|
||||||
|
|
||||||
const auto t_main_start = ggml_time_us();
|
const auto t_main_start = ggml_time_us();
|
||||||
|
|
||||||
LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
|
LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
|
||||||
@ -201,6 +205,11 @@ int main(int argc, char ** argv) {
|
|||||||
LOG_TEE("Processing requests ...\n\n");
|
LOG_TEE("Processing requests ...\n\n");
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
if (dump_kv_cache) {
|
||||||
|
llama_kv_cache_view_update(ctx, &kvc_view);
|
||||||
|
dump_kv_cache_view_seqs(kvc_view, 40);
|
||||||
|
}
|
||||||
|
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
// decode any currently ongoing sequences
|
// decode any currently ongoing sequences
|
||||||
|
128
llama.cpp
128
llama.cpp
@ -1280,6 +1280,7 @@ struct llama_kv_cache {
|
|||||||
// cannot be freely changed after a slot has been allocated.
|
// cannot be freely changed after a slot has been allocated.
|
||||||
uint32_t head = 0;
|
uint32_t head = 0;
|
||||||
uint32_t size = 0;
|
uint32_t size = 0;
|
||||||
|
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||||
|
|
||||||
// computed before each graph build
|
// computed before each graph build
|
||||||
uint32_t n = 0;
|
uint32_t n = 0;
|
||||||
@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(
|
|||||||
|
|
||||||
cache.head = 0;
|
cache.head = 0;
|
||||||
cache.size = n_ctx;
|
cache.size = n_ctx;
|
||||||
|
cache.used = 0;
|
||||||
|
|
||||||
cache.cells.clear();
|
cache.cells.clear();
|
||||||
cache.cells.resize(n_ctx);
|
cache.cells.resize(n_ctx);
|
||||||
@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cache.used += n_tokens;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1625,6 +1629,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
|
|||||||
cache.cells[i].seq_id.clear();
|
cache.cells[i].seq_id.clear();
|
||||||
}
|
}
|
||||||
cache.head = 0;
|
cache.head = 0;
|
||||||
|
cache.used = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_seq_rm(
|
static void llama_kv_cache_seq_rm(
|
||||||
@ -1647,6 +1652,9 @@ static void llama_kv_cache_seq_rm(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (cache.cells[i].seq_id.empty()) {
|
if (cache.cells[i].seq_id.empty()) {
|
||||||
|
// keep count of the number of used cells
|
||||||
|
if (cache.cells[i].pos >= 0) cache.used--;
|
||||||
|
|
||||||
cache.cells[i].pos = -1;
|
cache.cells[i].pos = -1;
|
||||||
if (new_head == cache.size) new_head = i;
|
if (new_head == cache.size) new_head = i;
|
||||||
}
|
}
|
||||||
@ -1654,7 +1662,7 @@ static void llama_kv_cache_seq_rm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
if (new_head != cache.size) cache.head = new_head;
|
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_seq_cp(
|
static void llama_kv_cache_seq_cp(
|
||||||
@ -1680,6 +1688,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
|
|||||||
|
|
||||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||||
if (!cache.cells[i].has_seq_id(seq_id)) {
|
if (!cache.cells[i].has_seq_id(seq_id)) {
|
||||||
|
if (cache.cells[i].pos >= 0) cache.used--;
|
||||||
cache.cells[i].pos = -1;
|
cache.cells[i].pos = -1;
|
||||||
cache.cells[i].seq_id.clear();
|
cache.cells[i].seq_id.clear();
|
||||||
if (new_head == cache.size) new_head = i;
|
if (new_head == cache.size) new_head = i;
|
||||||
@ -1690,7 +1699,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
if (new_head != cache.size) cache.head = new_head;
|
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_seq_shift(
|
static void llama_kv_cache_seq_shift(
|
||||||
@ -1711,6 +1720,7 @@ static void llama_kv_cache_seq_shift(
|
|||||||
cache.cells[i].delta += delta;
|
cache.cells[i].delta += delta;
|
||||||
|
|
||||||
if (cache.cells[i].pos < 0) {
|
if (cache.cells[i].pos < 0) {
|
||||||
|
if (!cache.cells[i].seq_id.empty()) cache.used--;
|
||||||
cache.cells[i].pos = -1;
|
cache.cells[i].pos = -1;
|
||||||
cache.cells[i].seq_id.clear();
|
cache.cells[i].seq_id.clear();
|
||||||
if (new_head == cache.size) new_head = i;
|
if (new_head == cache.size) new_head = i;
|
||||||
@ -5469,6 +5479,12 @@ static int llama_decode_internal(
|
|||||||
batch.seq_id = seq_id_arr.data();
|
batch.seq_id = seq_id_arr.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if we have enough unused cells before the current head ->
|
||||||
|
// better to start searching from the beginning of the cache, hoping to fill it
|
||||||
|
if (kv_self.head > kv_self.used + 2*n_tokens) {
|
||||||
|
kv_self.head = 0;
|
||||||
|
}
|
||||||
|
|
||||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -5479,7 +5495,7 @@ static int llama_decode_internal(
|
|||||||
//kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
|
//kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
|
||||||
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
|
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
|
||||||
|
|
||||||
//printf("kv_self.n = %d\n", kv_self.n);
|
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||||
|
|
||||||
ggml_allocr_reset(lctx.alloc);
|
ggml_allocr_reset(lctx.alloc);
|
||||||
|
|
||||||
@ -8789,8 +8805,107 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
|
||||||
|
struct llama_kv_cache_view result = {
|
||||||
|
/*.n_cells = */ 0,
|
||||||
|
/*.n_max_seq = */ n_max_seq,
|
||||||
|
/*.token_count = */ 0,
|
||||||
|
/*.used_cells = */ llama_get_kv_cache_used_cells(ctx),
|
||||||
|
/*.max_contiguous = */ 0,
|
||||||
|
/*.max_contiguous_idx = */ -1,
|
||||||
|
/*.cells = */ nullptr,
|
||||||
|
/*.cells_sequences = */ nullptr,
|
||||||
|
};
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
|
||||||
|
if (view->cells != nullptr) {
|
||||||
|
free(view->cells);
|
||||||
|
view->cells = nullptr;
|
||||||
|
}
|
||||||
|
if (view->cells_sequences != nullptr) {
|
||||||
|
free(view->cells_sequences);
|
||||||
|
view->cells_sequences = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
|
||||||
|
if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
|
||||||
|
view->n_cells = int32_t(ctx->kv_self.size);
|
||||||
|
void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
|
||||||
|
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
|
||||||
|
view->cells = (struct llama_kv_cache_view_cell *)p;
|
||||||
|
p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
|
||||||
|
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
|
||||||
|
view->cells_sequences = (llama_seq_id *)p;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
|
||||||
|
llama_kv_cache_view_cell * c_curr = view->cells;
|
||||||
|
llama_seq_id * cs_curr = view->cells_sequences;
|
||||||
|
int32_t used_cells = 0;
|
||||||
|
int32_t token_count = 0;
|
||||||
|
int32_t curr_contig_idx = -1;
|
||||||
|
uint32_t max_contig = 0;
|
||||||
|
int32_t max_contig_idx = -1;
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
|
||||||
|
const size_t curr_size = kv_cells[i].seq_id.size();
|
||||||
|
token_count += curr_size;
|
||||||
|
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
|
||||||
|
|
||||||
|
if (curr_size > 0) {
|
||||||
|
if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
|
||||||
|
max_contig = i - curr_contig_idx;
|
||||||
|
max_contig_idx = curr_contig_idx;
|
||||||
|
}
|
||||||
|
curr_contig_idx = -1;
|
||||||
|
} else if (curr_contig_idx < 0) {
|
||||||
|
curr_contig_idx = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
int seq_idx = 0;
|
||||||
|
for (const llama_seq_id it : kv_cells[i].seq_id) {
|
||||||
|
if (seq_idx >= view->n_max_seq) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
cs_curr[seq_idx] = it;
|
||||||
|
seq_idx++;
|
||||||
|
}
|
||||||
|
if (seq_idx != 0) {
|
||||||
|
used_cells++;
|
||||||
|
}
|
||||||
|
for (; seq_idx < view->n_max_seq; seq_idx++) {
|
||||||
|
cs_curr[seq_idx] = -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
|
||||||
|
max_contig_idx = curr_contig_idx;
|
||||||
|
max_contig = kv_cells.size() - curr_contig_idx;
|
||||||
|
}
|
||||||
|
view->max_contiguous = max_contig;
|
||||||
|
view->max_contiguous_idx = max_contig_idx;
|
||||||
|
view->token_count = token_count;
|
||||||
|
view->used_cells = used_cells;
|
||||||
|
if (uint32_t(used_cells) != ctx->kv_self.used) {
|
||||||
|
LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
|
||||||
|
__func__, ctx->kv_self.used, used_cells);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
|
int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
|
||||||
return ctx->kv_self.head;
|
int result = 0;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
|
||||||
|
result += ctx->kv_self.cells[i].seq_id.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
|
||||||
|
return ctx->kv_self.used;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_clear(struct llama_context * ctx) {
|
void llama_kv_cache_clear(struct llama_context * ctx) {
|
||||||
@ -8960,10 +9075,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
|||||||
const size_t kv_buf_size = kv_self.buf.size;
|
const size_t kv_buf_size = kv_self.buf.size;
|
||||||
const uint32_t kv_head = kv_self.head;
|
const uint32_t kv_head = kv_self.head;
|
||||||
const uint32_t kv_size = kv_self.size;
|
const uint32_t kv_size = kv_self.size;
|
||||||
|
const uint32_t kv_used = kv_self.used;
|
||||||
|
|
||||||
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
|
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
|
||||||
data_ctx->write(&kv_head, sizeof(kv_head));
|
data_ctx->write(&kv_head, sizeof(kv_head));
|
||||||
data_ctx->write(&kv_size, sizeof(kv_size));
|
data_ctx->write(&kv_size, sizeof(kv_size));
|
||||||
|
data_ctx->write(&kv_used, sizeof(kv_used));
|
||||||
|
|
||||||
if (kv_buf_size) {
|
if (kv_buf_size) {
|
||||||
const size_t elt_size = ggml_element_size(kv_self.k);
|
const size_t elt_size = ggml_element_size(kv_self.k);
|
||||||
@ -9086,10 +9203,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
|
|||||||
size_t kv_buf_size;
|
size_t kv_buf_size;
|
||||||
uint32_t kv_head;
|
uint32_t kv_head;
|
||||||
uint32_t kv_size;
|
uint32_t kv_size;
|
||||||
|
uint32_t kv_used;
|
||||||
|
|
||||||
memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
|
memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
|
||||||
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
|
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
|
||||||
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
|
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
|
||||||
|
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
|
||||||
|
|
||||||
if (kv_buf_size) {
|
if (kv_buf_size) {
|
||||||
GGML_ASSERT(kv_self.buf.size == kv_buf_size);
|
GGML_ASSERT(kv_self.buf.size == kv_buf_size);
|
||||||
@ -9124,6 +9243,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
|
|||||||
|
|
||||||
ctx->kv_self.head = kv_head;
|
ctx->kv_self.head = kv_head;
|
||||||
ctx->kv_self.size = kv_size;
|
ctx->kv_self.size = kv_size;
|
||||||
|
ctx->kv_self.used = kv_used;
|
||||||
|
|
||||||
ctx->kv_self.cells.resize(kv_size);
|
ctx->kv_self.cells.resize(kv_size);
|
||||||
|
|
||||||
|
57
llama.h
57
llama.h
@ -361,9 +361,60 @@ extern "C" {
|
|||||||
// KV cache
|
// KV cache
|
||||||
//
|
//
|
||||||
|
|
||||||
// Returns the number of tokens in the KV cache
|
// Information associated with an individual cell in the KV cache view.
|
||||||
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
struct llama_kv_cache_view_cell {
|
||||||
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
|
// The position for this cell. Takes KV cache shifts into account.
|
||||||
|
// May be negative if the cell is not populated.
|
||||||
|
llama_pos pos;
|
||||||
|
};
|
||||||
|
|
||||||
|
// An updateable view of the KV cache.
|
||||||
|
struct llama_kv_cache_view {
|
||||||
|
// Number of KV cache cells. This will be the same as the context size.
|
||||||
|
int32_t n_cells;
|
||||||
|
|
||||||
|
// Maximum number of sequences that can exist in a cell. It's not an error
|
||||||
|
// if there are more sequences in a cell than this value, however they will
|
||||||
|
// not be visible in the view cells_sequences.
|
||||||
|
int32_t n_max_seq;
|
||||||
|
|
||||||
|
// Number of tokens in the cache. For example, if there are two populated
|
||||||
|
// cells, the first with 1 sequence id in it and the second with 2 sequence
|
||||||
|
// ids then you'll have 3 tokens.
|
||||||
|
int32_t token_count;
|
||||||
|
|
||||||
|
// Number of populated cache cells.
|
||||||
|
int32_t used_cells;
|
||||||
|
|
||||||
|
// Maximum contiguous empty slots in the cache.
|
||||||
|
int32_t max_contiguous;
|
||||||
|
|
||||||
|
// Index to the start of the max_contiguous slot range. Can be negative
|
||||||
|
// when cache is full.
|
||||||
|
int32_t max_contiguous_idx;
|
||||||
|
|
||||||
|
// Information for an individual cell.
|
||||||
|
struct llama_kv_cache_view_cell * cells;
|
||||||
|
|
||||||
|
// The sequences for each cell. There will be n_max_seq items per cell.
|
||||||
|
llama_seq_id * cells_sequences;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create an empty KV cache view. (use only for debugging purposes)
|
||||||
|
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
|
||||||
|
|
||||||
|
// Free a KV cache view. (use only for debugging purposes)
|
||||||
|
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
|
||||||
|
|
||||||
|
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
||||||
|
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
|
||||||
|
|
||||||
|
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||||
|
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||||
|
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
||||||
|
|
||||||
|
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||||
|
LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
||||||
|
|
||||||
// Clear the KV cache
|
// Clear the KV cache
|
||||||
LLAMA_API void llama_kv_cache_clear(
|
LLAMA_API void llama_kv_cache_clear(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user