mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-30 22:03:03 +01:00
kv_cache : functions -> members
ggml-ci
This commit is contained in:
parent
3d5908092f
commit
53a7c20d89
@ -1169,7 +1169,7 @@ struct llama_data_read {
|
|||||||
}
|
}
|
||||||
batch.n_seq_id[0] = 1;
|
batch.n_seq_id[0] = 1;
|
||||||
batch.seq_id[0] = &dest_seq_id;
|
batch.seq_id[0] = &dest_seq_id;
|
||||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
if (!kv_self.find_slot(batch)) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -11,41 +11,35 @@
|
|||||||
|
|
||||||
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
|
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
|
||||||
|
|
||||||
uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
|
bool llama_kv_cache::init(
|
||||||
// the FA kernels require padding to avoid extra runtime boundary checks
|
const llama_model & model,
|
||||||
return cparams.flash_attn ? 256u : 32u;
|
const llama_cparams & cparams,
|
||||||
}
|
ggml_type type_k,
|
||||||
|
ggml_type type_v,
|
||||||
bool llama_kv_cache_init(
|
uint32_t kv_size,
|
||||||
struct llama_kv_cache & cache,
|
bool offload) {
|
||||||
const llama_model & model,
|
|
||||||
const llama_cparams & cparams,
|
|
||||||
ggml_type type_k,
|
|
||||||
ggml_type type_v,
|
|
||||||
uint32_t kv_size,
|
|
||||||
bool offload) {
|
|
||||||
const struct llama_hparams & hparams = model.hparams;
|
const struct llama_hparams & hparams = model.hparams;
|
||||||
|
|
||||||
const int32_t n_layer = hparams.n_layer;
|
const int32_t n_layer = hparams.n_layer;
|
||||||
|
|
||||||
cache.has_shift = false;
|
has_shift = false;
|
||||||
|
|
||||||
cache.recurrent = llama_model_is_recurrent(&model);
|
recurrent = llama_model_is_recurrent(&model);
|
||||||
cache.v_trans = !cache.recurrent && !cparams.flash_attn;
|
v_trans = !recurrent && !cparams.flash_attn;
|
||||||
cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
|
can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
|
LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
|
||||||
__func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift);
|
__func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
|
||||||
|
|
||||||
cache.head = 0;
|
head = 0;
|
||||||
cache.size = kv_size;
|
size = kv_size;
|
||||||
cache.used = 0;
|
used = 0;
|
||||||
|
|
||||||
cache.type_k = type_k;
|
type_k = type_k;
|
||||||
cache.type_v = type_v;
|
type_v = type_v;
|
||||||
|
|
||||||
cache.cells.clear();
|
cells.clear();
|
||||||
cache.cells.resize(kv_size);
|
cells.resize(kv_size);
|
||||||
|
|
||||||
// create a context for each buffer type
|
// create a context for each buffer type
|
||||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||||
@ -57,19 +51,23 @@ bool llama_kv_cache_init(
|
|||||||
/*.mem_buffer =*/ NULL,
|
/*.mem_buffer =*/ NULL,
|
||||||
/*.no_alloc =*/ true,
|
/*.no_alloc =*/ true,
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_context * ctx = ggml_init(params);
|
ggml_context * ctx = ggml_init(params);
|
||||||
if (!ctx) {
|
if (!ctx) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx_map[buft] = ctx;
|
ctx_map[buft] = ctx;
|
||||||
cache.ctxs.emplace_back(ctx);
|
ctxs.emplace_back(ctx);
|
||||||
|
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
return it->second;
|
return it->second;
|
||||||
};
|
};
|
||||||
|
|
||||||
cache.k_l.reserve(n_layer);
|
k_l.reserve(n_layer);
|
||||||
cache.v_l.reserve(n_layer);
|
v_l.reserve(n_layer);
|
||||||
|
|
||||||
for (int i = 0; i < n_layer; i++) {
|
for (int i = 0; i < n_layer; i++) {
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
||||||
@ -95,8 +93,8 @@ bool llama_kv_cache_init(
|
|||||||
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
||||||
ggml_format_name(k, "cache_k_l%d", i);
|
ggml_format_name(k, "cache_k_l%d", i);
|
||||||
ggml_format_name(v, "cache_v_l%d", i);
|
ggml_format_name(v, "cache_v_l%d", i);
|
||||||
cache.k_l.push_back(k);
|
k_l.push_back(k);
|
||||||
cache.v_l.push_back(v);
|
v_l.push_back(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
||||||
@ -111,20 +109,339 @@ bool llama_kv_cache_init(
|
|||||||
}
|
}
|
||||||
ggml_backend_buffer_clear(buf, 0);
|
ggml_backend_buffer_clear(buf, 0);
|
||||||
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
||||||
cache.bufs.emplace_back(buf);
|
bufs.emplace_back(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
int32_t llama_kv_cache::n_tokens() const {
|
||||||
struct llama_kv_cache & cache,
|
int32_t result = 0;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < size; i++) {
|
||||||
|
result += cells[i].seq_id.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t llama_kv_cache::total_size() const {
|
||||||
|
size_t size = 0;
|
||||||
|
for (const auto & buf : bufs) {
|
||||||
|
size += ggml_backend_buffer_get_size(buf.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: better data structures to reduce the cost of this operation
|
||||||
|
llama_pos llama_kv_cache::max_pos() const {
|
||||||
|
llama_pos max_pos = -1;
|
||||||
|
for (const auto & cell : cells) {
|
||||||
|
max_pos = std::max(max_pos, cell.pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
return max_pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache::clear() {
|
||||||
|
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
||||||
|
cells[i].pos = -1;
|
||||||
|
cells[i].seq_id.clear();
|
||||||
|
cells[i].src = -1;
|
||||||
|
cells[i].tail = -1;
|
||||||
|
}
|
||||||
|
head = 0;
|
||||||
|
used = 0;
|
||||||
|
|
||||||
|
for (auto & buf : bufs) {
|
||||||
|
ggml_backend_buffer_clear(buf.get(), 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||||
|
uint32_t new_head = size;
|
||||||
|
|
||||||
|
if (p0 < 0) {
|
||||||
|
p0 = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (p1 < 0) {
|
||||||
|
p1 = std::numeric_limits<llama_pos>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
// models like Mamba or RWKV can't have a state partially erased
|
||||||
|
if (recurrent) {
|
||||||
|
if (seq_id >= (int64_t) size) {
|
||||||
|
// could be fatal
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (0 <= seq_id) {
|
||||||
|
int32_t & tail_id = cells[seq_id].tail;
|
||||||
|
if (tail_id >= 0) {
|
||||||
|
const llama_kv_cell & cell = cells[tail_id];
|
||||||
|
// partial intersection is invalid
|
||||||
|
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// invalidate tails which will be cleared
|
||||||
|
if (p0 <= cell.pos && cell.pos < p1) {
|
||||||
|
tail_id = -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// seq_id is negative, then the range should include everything or nothing
|
||||||
|
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
|
if (cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||||
|
if (seq_id < 0) {
|
||||||
|
cells[i].seq_id.clear();
|
||||||
|
} else if (cells[i].has_seq_id(seq_id)) {
|
||||||
|
cells[i].seq_id.erase(seq_id);
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (cells[i].is_empty()) {
|
||||||
|
// keep count of the number of used cells
|
||||||
|
if (cells[i].pos >= 0) {
|
||||||
|
used--;
|
||||||
|
}
|
||||||
|
|
||||||
|
cells[i].pos = -1;
|
||||||
|
cells[i].src = -1;
|
||||||
|
|
||||||
|
if (new_head == size) {
|
||||||
|
new_head = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
|
if (new_head != size && new_head < head) {
|
||||||
|
head = new_head;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||||
|
if (seq_id_src == seq_id_dst) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (p0 < 0) {
|
||||||
|
p0 = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (p1 < 0) {
|
||||||
|
p1 = std::numeric_limits<llama_pos>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (recurrent) {
|
||||||
|
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
||||||
|
llama_kv_cell & tail_src = cells[seq_id_src];
|
||||||
|
llama_kv_cell & tail_dst = cells[seq_id_dst];
|
||||||
|
if (tail_dst.tail >= 0) {
|
||||||
|
// clear destination seq_id if it wasn't empty
|
||||||
|
llama_kv_cell & cell_dst = cells[tail_dst.tail];
|
||||||
|
|
||||||
|
cell_dst.seq_id.erase(seq_id_dst);
|
||||||
|
tail_dst.tail = -1;
|
||||||
|
if (cell_dst.seq_id.empty()) {
|
||||||
|
cell_dst.pos = -1;
|
||||||
|
cell_dst.delta = -1;
|
||||||
|
cell_dst.src = -1;
|
||||||
|
used -= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tail_src.tail >= 0) {
|
||||||
|
llama_kv_cell & cell_src = cells[tail_src.tail];
|
||||||
|
|
||||||
|
cell_src.seq_id.insert(seq_id_dst);
|
||||||
|
tail_dst.tail = tail_src.tail;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// otherwise, this is the KV of a Transformer-like model
|
||||||
|
head = 0;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
|
if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||||
|
cells[i].seq_id.insert(seq_id_dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
|
||||||
|
uint32_t new_head = size;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
|
if (recurrent && (llama_seq_id) i != seq_id) {
|
||||||
|
cells[i].tail = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!cells[i].has_seq_id(seq_id)) {
|
||||||
|
if (cells[i].pos >= 0) {
|
||||||
|
used--;
|
||||||
|
}
|
||||||
|
|
||||||
|
cells[i].pos = -1;
|
||||||
|
cells[i].src = -1;
|
||||||
|
cells[i].seq_id.clear();
|
||||||
|
|
||||||
|
if (new_head == size){
|
||||||
|
new_head = i;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cells[i].seq_id.clear();
|
||||||
|
cells[i].seq_id.insert(seq_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
|
if (new_head != size && new_head < head) {
|
||||||
|
head = new_head;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
||||||
|
if (delta == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t new_head = size;
|
||||||
|
|
||||||
|
if (p0 < 0) {
|
||||||
|
p0 = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (p1 < 0) {
|
||||||
|
p1 = std::numeric_limits<llama_pos>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is no range then return early to avoid looping over the
|
||||||
|
if (p0 == p1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (recurrent) {
|
||||||
|
// for Mamba-like or RWKV models, only the pos needs to be shifted
|
||||||
|
if (0 <= seq_id && seq_id < (int64_t) size) {
|
||||||
|
const int32_t tail_id = cells[seq_id].tail;
|
||||||
|
if (tail_id >= 0) {
|
||||||
|
llama_kv_cell & cell = cells[tail_id];
|
||||||
|
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
||||||
|
cell.pos += delta;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
|
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||||
|
has_shift = true;
|
||||||
|
cells[i].pos += delta;
|
||||||
|
cells[i].delta += delta;
|
||||||
|
|
||||||
|
if (cells[i].pos < 0) {
|
||||||
|
if (!cells[i].is_empty()) {
|
||||||
|
used--;
|
||||||
|
}
|
||||||
|
cells[i].pos = -1;
|
||||||
|
cells[i].seq_id.clear();
|
||||||
|
if (new_head == size) {
|
||||||
|
new_head = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
|
// Otherwise we just start the next search from the beginning.
|
||||||
|
head = new_head != size ? new_head : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||||
|
if (d == 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (p0 < 0) {
|
||||||
|
p0 = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (p1 < 0) {
|
||||||
|
p1 = std::numeric_limits<llama_pos>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is no range then return early to avoid looping over the cache.
|
||||||
|
if (p0 == p1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (recurrent) {
|
||||||
|
// for Mamba-like or RWKV models, only the pos needs to be changed
|
||||||
|
if (0 <= seq_id && seq_id < (int64_t) size) {
|
||||||
|
const int32_t tail_id = cells[seq_id].tail;
|
||||||
|
if (tail_id >= 0) {
|
||||||
|
llama_kv_cell & cell = cells[tail_id];
|
||||||
|
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
||||||
|
cell.pos /= d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
|
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||||
|
has_shift = true;
|
||||||
|
|
||||||
|
{
|
||||||
|
llama_pos p_old = cells[i].pos;
|
||||||
|
cells[i].pos /= d;
|
||||||
|
cells[i].delta += cells[i].pos - p_old;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) {
|
||||||
|
llama_pos result = 0;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
|
if (cells[i].has_seq_id(seq_id)) {
|
||||||
|
result = std::max(result, cells[i].pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache::defrag() {
|
||||||
|
if (!recurrent) {
|
||||||
|
do_defrag = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_kv_cache_slot_info llama_kv_cache::find_slot(
|
||||||
const struct llama_ubatch & ubatch) {
|
const struct llama_ubatch & ubatch) {
|
||||||
const uint32_t n_tokens = ubatch.n_tokens;
|
const uint32_t n_tokens = ubatch.n_tokens;
|
||||||
const uint32_t n_seqs = ubatch.n_seqs;
|
const uint32_t n_seqs = ubatch.n_seqs;
|
||||||
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
|
|
||||||
if (cache.recurrent) {
|
if (recurrent) {
|
||||||
// For recurrent state architectures (like Mamba or RWKV),
|
// For recurrent state architectures (like Mamba or RWKV),
|
||||||
// each cache cell can store the state for a whole sequence.
|
// each cache cell can store the state for a whole sequence.
|
||||||
// A slot should be always be contiguous.
|
// A slot should be always be contiguous.
|
||||||
@ -132,7 +449,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
// can only process batches with an equal number of new tokens in each sequence
|
// can only process batches with an equal number of new tokens in each sequence
|
||||||
GGML_ASSERT(ubatch.equal_seqs);
|
GGML_ASSERT(ubatch.equal_seqs);
|
||||||
|
|
||||||
int32_t min = cache.size - 1;
|
int32_t min = size - 1;
|
||||||
int32_t max = 0;
|
int32_t max = 0;
|
||||||
|
|
||||||
// everything should fit if all seq_ids are smaller than the max
|
// everything should fit if all seq_ids are smaller than the max
|
||||||
@ -141,16 +458,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||||
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
||||||
|
|
||||||
if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
|
if (seq_id < 0 || (uint32_t) seq_id >= size) {
|
||||||
// too big seq_id
|
// too big seq_id
|
||||||
// TODO: would it be possible to resize the cache instead?
|
// TODO: would it be possible to resize the cache instead?
|
||||||
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
|
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
|
||||||
return llama_kv_cache_slot_info_failed;
|
return llama_kv_cache_slot_info_failed;
|
||||||
}
|
}
|
||||||
if (j > 0) {
|
if (j > 0) {
|
||||||
llama_kv_cell & seq = cache.cells[seq_id];
|
llama_kv_cell & seq = cells[seq_id];
|
||||||
if (seq.tail >= 0) {
|
if (seq.tail >= 0) {
|
||||||
llama_kv_cell & cell = cache.cells[seq.tail];
|
llama_kv_cell & cell = cells[seq.tail];
|
||||||
// clear cells from seq_ids that become shared
|
// clear cells from seq_ids that become shared
|
||||||
// (should not normally happen, but let's handle it anyway)
|
// (should not normally happen, but let's handle it anyway)
|
||||||
cell.seq_id.erase(seq_id);
|
cell.seq_id.erase(seq_id);
|
||||||
@ -158,7 +475,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
if (cell.seq_id.empty()) {
|
if (cell.seq_id.empty()) {
|
||||||
cell.pos = -1;
|
cell.pos = -1;
|
||||||
cell.src = -1;
|
cell.src = -1;
|
||||||
cache.used -= 1;
|
used -= 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -168,9 +485,9 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
{
|
{
|
||||||
std::vector<int32_t> tails_verif;
|
std::vector<int32_t> tails_verif;
|
||||||
tails_verif.assign(cache.size, -1);
|
tails_verif.assign(size, -1);
|
||||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
llama_kv_cell & cell = cache.cells[i];
|
llama_kv_cell & cell = cells[i];
|
||||||
for (llama_seq_id seq_id : cell.seq_id) {
|
for (llama_seq_id seq_id : cell.seq_id) {
|
||||||
if (tails_verif[seq_id] != -1) {
|
if (tails_verif[seq_id] != -1) {
|
||||||
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
||||||
@ -178,20 +495,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
tails_verif[seq_id] = i;
|
tails_verif[seq_id] = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
if (tails_verif[i] != cache.cells[i].tail) {
|
if (tails_verif[i] != cells[i].tail) {
|
||||||
LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
|
LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// find next empty cell
|
// find next empty cell
|
||||||
uint32_t next_empty_cell = cache.head;
|
uint32_t next_empty_cell = head;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
|
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
||||||
llama_kv_cell & cell = cache.cells[next_empty_cell];
|
llama_kv_cell & cell = cells[next_empty_cell];
|
||||||
if (cell.is_empty()) { break; }
|
if (cell.is_empty()) { break; }
|
||||||
next_empty_cell += 1;
|
next_empty_cell += 1;
|
||||||
}
|
}
|
||||||
@ -199,20 +516,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
// find usable cell range
|
// find usable cell range
|
||||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||||
llama_kv_cell & seq_meta = cache.cells[seq_id];
|
llama_kv_cell & seq_meta = cells[seq_id];
|
||||||
bool has_cell = false;
|
bool has_cell = false;
|
||||||
if (seq_meta.tail >= 0) {
|
if (seq_meta.tail >= 0) {
|
||||||
llama_kv_cell & cell = cache.cells[seq_meta.tail];
|
llama_kv_cell & cell = cells[seq_meta.tail];
|
||||||
GGML_ASSERT(cell.has_seq_id(seq_id));
|
GGML_ASSERT(cell.has_seq_id(seq_id));
|
||||||
// does this seq_id "own" the cell?
|
// does this seq_id "own" the cell?
|
||||||
if (cell.seq_id.size() == 1) { has_cell = true; }
|
if (cell.seq_id.size() == 1) { has_cell = true; }
|
||||||
}
|
}
|
||||||
if (!has_cell) {
|
if (!has_cell) {
|
||||||
llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
|
llama_kv_cell & empty_cell = cells[next_empty_cell];
|
||||||
GGML_ASSERT(empty_cell.is_empty());
|
GGML_ASSERT(empty_cell.is_empty());
|
||||||
// copy old tail into the empty cell
|
// copy old tail into the empty cell
|
||||||
if (seq_meta.tail >= 0) {
|
if (seq_meta.tail >= 0) {
|
||||||
llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
|
llama_kv_cell & orig_cell = cells[seq_meta.tail];
|
||||||
empty_cell.pos = orig_cell.pos;
|
empty_cell.pos = orig_cell.pos;
|
||||||
empty_cell.src = orig_cell.src;
|
empty_cell.src = orig_cell.src;
|
||||||
orig_cell.seq_id.erase(seq_id);
|
orig_cell.seq_id.erase(seq_id);
|
||||||
@ -222,9 +539,9 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
// find next empty cell
|
// find next empty cell
|
||||||
if (s + 1 < n_seqs) {
|
if (s + 1 < n_seqs) {
|
||||||
next_empty_cell += 1;
|
next_empty_cell += 1;
|
||||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
|
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
||||||
llama_kv_cell & cell = cache.cells[next_empty_cell];
|
llama_kv_cell & cell = cells[next_empty_cell];
|
||||||
if (cell.is_empty()) { break; }
|
if (cell.is_empty()) { break; }
|
||||||
next_empty_cell += 1;
|
next_empty_cell += 1;
|
||||||
}
|
}
|
||||||
@ -237,10 +554,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
// gather and re-order
|
// gather and re-order
|
||||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||||
int32_t dst_id = s + min;
|
int32_t dst_id = s + min;
|
||||||
int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
|
int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
||||||
if (dst_id != src_id) {
|
if (dst_id != src_id) {
|
||||||
llama_kv_cell & dst_cell = cache.cells[dst_id];
|
llama_kv_cell & dst_cell = cells[dst_id];
|
||||||
llama_kv_cell & src_cell = cache.cells[src_id];
|
llama_kv_cell & src_cell = cells[src_id];
|
||||||
|
|
||||||
std::swap(dst_cell.pos, src_cell.pos);
|
std::swap(dst_cell.pos, src_cell.pos);
|
||||||
std::swap(dst_cell.src, src_cell.src);
|
std::swap(dst_cell.src, src_cell.src);
|
||||||
@ -248,10 +565,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
|
|
||||||
// swap tails (assuming they NEVER overlap)
|
// swap tails (assuming they NEVER overlap)
|
||||||
for (const llama_seq_id seq_id : src_cell.seq_id) {
|
for (const llama_seq_id seq_id : src_cell.seq_id) {
|
||||||
cache.cells[seq_id].tail = src_id;
|
cells[seq_id].tail = src_id;
|
||||||
}
|
}
|
||||||
for (const llama_seq_id seq_id : dst_cell.seq_id) {
|
for (const llama_seq_id seq_id : dst_cell.seq_id) {
|
||||||
cache.cells[seq_id].tail = dst_id;
|
cells[seq_id].tail = dst_id;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -260,7 +577,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||||
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
||||||
int32_t cell_id = s + min;
|
int32_t cell_id = s + min;
|
||||||
llama_kv_cell & cell = cache.cells[cell_id];
|
llama_kv_cell & cell = cells[cell_id];
|
||||||
|
|
||||||
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
||||||
// What should happen when the pos backtracks or skips a value?
|
// What should happen when the pos backtracks or skips a value?
|
||||||
@ -273,41 +590,41 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
|
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
|
||||||
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
||||||
cell.seq_id.insert(seq_id);
|
cell.seq_id.insert(seq_id);
|
||||||
cache.cells[seq_id].tail = cell_id;
|
cells[seq_id].tail = cell_id;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// allow getting the range of used cells, from head to head + n
|
// allow getting the range of used cells, from head to head + n
|
||||||
cache.head = min;
|
head = min;
|
||||||
cache.n = max - min + 1;
|
n = max - min + 1;
|
||||||
cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
|
used = std::count_if(cells.begin(), cells.end(),
|
||||||
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
|
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
|
||||||
|
|
||||||
// sanity check
|
// sanity check
|
||||||
return llama_kv_cache_slot_info(cache.n >= n_seqs);
|
return llama_kv_cache_slot_info(n >= n_seqs);
|
||||||
}
|
}
|
||||||
// otherwise, one cell per token.
|
// otherwise, one cell per token.
|
||||||
|
|
||||||
if (n_tokens > cache.size) {
|
if (n_tokens > size) {
|
||||||
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
|
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
|
||||||
return llama_kv_cache_slot_info_failed;
|
return llama_kv_cache_slot_info_failed;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_tested = 0;
|
uint32_t n_tested = 0;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
if (cache.head + n_tokens > cache.size) {
|
if (head + n_tokens > size) {
|
||||||
n_tested += cache.size - cache.head;
|
n_tested += size - head;
|
||||||
cache.head = 0;
|
head = 0;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool found = true;
|
bool found = true;
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
if (cache.cells[cache.head + i].pos >= 0) {
|
if (cells[head + i].pos >= 0) {
|
||||||
found = false;
|
found = false;
|
||||||
cache.head += i + 1;
|
head += i + 1;
|
||||||
n_tested += i + 1;
|
n_tested += i + 1;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -316,7 +633,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_tested >= cache.size) {
|
if (n_tested >= size) {
|
||||||
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||||
return llama_kv_cache_slot_info_failed;
|
return llama_kv_cache_slot_info_failed;
|
||||||
}
|
}
|
||||||
@ -325,22 +642,27 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|||||||
for (uint32_t s = 0; s < n_seqs; s++) {
|
for (uint32_t s = 0; s < n_seqs; s++) {
|
||||||
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
|
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
|
||||||
uint32_t k = s*n_seq_tokens + i;
|
uint32_t k = s*n_seq_tokens + i;
|
||||||
cache.cells[cache.head + k].pos = ubatch.pos[k];
|
cells[head + k].pos = ubatch.pos[k];
|
||||||
|
|
||||||
for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
|
for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
|
||||||
cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
|
cells[head + k].seq_id.insert(ubatch.seq_id[s][j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cache.used += n_tokens;
|
used += n_tokens;
|
||||||
|
|
||||||
return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
|
return llama_kv_cache_slot_info(head, head + n_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
|
uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) const {
|
||||||
for (uint32_t i = cache.size; i > 0; --i) {
|
// the FA kernels require padding to avoid extra runtime boundary checks
|
||||||
const llama_kv_cell & cell = cache.cells[i - 1];
|
return cparams.flash_attn ? 256u : 32u;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t llama_kv_cache::cell_max() const {
|
||||||
|
for (uint32_t i = size; i > 0; --i) {
|
||||||
|
const llama_kv_cell & cell = cells[i - 1];
|
||||||
|
|
||||||
if (cell.pos >= 0 && !cell.is_empty()) {
|
if (cell.pos >= 0 && !cell.is_empty()) {
|
||||||
return i;
|
return i;
|
||||||
|
@ -7,6 +7,9 @@
|
|||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
struct llama_cparams;
|
||||||
|
struct llama_ubatch;
|
||||||
|
|
||||||
struct llama_kv_cell {
|
struct llama_kv_cell {
|
||||||
llama_pos pos = -1;
|
llama_pos pos = -1;
|
||||||
llama_pos delta = 0;
|
llama_pos delta = 0;
|
||||||
@ -28,7 +31,19 @@ struct llama_kv_cell {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// a structure holds information about the slot found in llama_kv_cache_find_slot
|
||||||
|
struct llama_kv_cache_slot_info {
|
||||||
|
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
|
||||||
|
bool found = false; // the slot was found
|
||||||
|
|
||||||
|
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
|
||||||
|
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
|
||||||
|
|
||||||
|
operator bool() const { return found; }
|
||||||
|
};
|
||||||
|
|
||||||
// ring-buffer of cached KV data
|
// ring-buffer of cached KV data
|
||||||
|
// TODO: pimpl
|
||||||
struct llama_kv_cache {
|
struct llama_kv_cache {
|
||||||
bool has_shift = false;
|
bool has_shift = false;
|
||||||
bool do_defrag = false;
|
bool do_defrag = false;
|
||||||
@ -57,343 +72,8 @@ struct llama_kv_cache {
|
|||||||
std::vector<ggml_context_ptr> ctxs;
|
std::vector<ggml_context_ptr> ctxs;
|
||||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||||
|
|
||||||
int32_t n_tokens() const {
|
// TODO: become constructor
|
||||||
int32_t result = 0;
|
bool init(
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; i++) {
|
|
||||||
result += cells[i].seq_id.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t total_size() const {
|
|
||||||
size_t size = 0;
|
|
||||||
for (const auto & buf : bufs) {
|
|
||||||
size += ggml_backend_buffer_get_size(buf.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
return size;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: better data structures to reduce the cost of this operation
|
|
||||||
llama_pos max_pos() const {
|
|
||||||
llama_pos max_pos = -1;
|
|
||||||
for (const auto & cell : cells) {
|
|
||||||
max_pos = std::max(max_pos, cell.pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
return max_pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
void clear() {
|
|
||||||
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
|
||||||
cells[i].pos = -1;
|
|
||||||
cells[i].seq_id.clear();
|
|
||||||
cells[i].src = -1;
|
|
||||||
cells[i].tail = -1;
|
|
||||||
}
|
|
||||||
head = 0;
|
|
||||||
used = 0;
|
|
||||||
|
|
||||||
for (auto & buf : bufs) {
|
|
||||||
ggml_backend_buffer_clear(buf.get(), 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
|
||||||
uint32_t new_head = size;
|
|
||||||
|
|
||||||
if (p0 < 0) {
|
|
||||||
p0 = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (p1 < 0) {
|
|
||||||
p1 = std::numeric_limits<llama_pos>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
// models like Mamba or RWKV can't have a state partially erased
|
|
||||||
if (recurrent) {
|
|
||||||
if (seq_id >= (int64_t) size) {
|
|
||||||
// could be fatal
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (0 <= seq_id) {
|
|
||||||
int32_t & tail_id = cells[seq_id].tail;
|
|
||||||
if (tail_id >= 0) {
|
|
||||||
const llama_kv_cell & cell = cells[tail_id];
|
|
||||||
// partial intersection is invalid
|
|
||||||
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// invalidate tails which will be cleared
|
|
||||||
if (p0 <= cell.pos && cell.pos < p1) {
|
|
||||||
tail_id = -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// seq_id is negative, then the range should include everything or nothing
|
|
||||||
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
|
||||||
if (cells[i].pos >= p0 && cells[i].pos < p1) {
|
|
||||||
if (seq_id < 0) {
|
|
||||||
cells[i].seq_id.clear();
|
|
||||||
} else if (cells[i].has_seq_id(seq_id)) {
|
|
||||||
cells[i].seq_id.erase(seq_id);
|
|
||||||
} else {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (cells[i].is_empty()) {
|
|
||||||
// keep count of the number of used cells
|
|
||||||
if (cells[i].pos >= 0) {
|
|
||||||
used--;
|
|
||||||
}
|
|
||||||
|
|
||||||
cells[i].pos = -1;
|
|
||||||
cells[i].src = -1;
|
|
||||||
|
|
||||||
if (new_head == size) {
|
|
||||||
new_head = i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
|
||||||
if (new_head != size && new_head < head) {
|
|
||||||
head = new_head;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
|
||||||
if (seq_id_src == seq_id_dst) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (p0 < 0) {
|
|
||||||
p0 = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (p1 < 0) {
|
|
||||||
p1 = std::numeric_limits<llama_pos>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (recurrent) {
|
|
||||||
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
|
||||||
llama_kv_cell & tail_src = cells[seq_id_src];
|
|
||||||
llama_kv_cell & tail_dst = cells[seq_id_dst];
|
|
||||||
if (tail_dst.tail >= 0) {
|
|
||||||
// clear destination seq_id if it wasn't empty
|
|
||||||
llama_kv_cell & cell_dst = cells[tail_dst.tail];
|
|
||||||
|
|
||||||
cell_dst.seq_id.erase(seq_id_dst);
|
|
||||||
tail_dst.tail = -1;
|
|
||||||
if (cell_dst.seq_id.empty()) {
|
|
||||||
cell_dst.pos = -1;
|
|
||||||
cell_dst.delta = -1;
|
|
||||||
cell_dst.src = -1;
|
|
||||||
used -= 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (tail_src.tail >= 0) {
|
|
||||||
llama_kv_cell & cell_src = cells[tail_src.tail];
|
|
||||||
|
|
||||||
cell_src.seq_id.insert(seq_id_dst);
|
|
||||||
tail_dst.tail = tail_src.tail;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// otherwise, this is the KV of a Transformer-like model
|
|
||||||
head = 0;
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
|
||||||
if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
|
||||||
cells[i].seq_id.insert(seq_id_dst);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void seq_keep(llama_seq_id seq_id) {
|
|
||||||
uint32_t new_head = size;
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
|
||||||
if (recurrent && (llama_seq_id) i != seq_id) {
|
|
||||||
cells[i].tail = -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!cells[i].has_seq_id(seq_id)) {
|
|
||||||
if (cells[i].pos >= 0) {
|
|
||||||
used--;
|
|
||||||
}
|
|
||||||
|
|
||||||
cells[i].pos = -1;
|
|
||||||
cells[i].src = -1;
|
|
||||||
cells[i].seq_id.clear();
|
|
||||||
|
|
||||||
if (new_head == size){
|
|
||||||
new_head = i;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cells[i].seq_id.clear();
|
|
||||||
cells[i].seq_id.insert(seq_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
|
||||||
if (new_head != size && new_head < head) {
|
|
||||||
head = new_head;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
|
||||||
if (delta == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t new_head = size;
|
|
||||||
|
|
||||||
if (p0 < 0) {
|
|
||||||
p0 = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (p1 < 0) {
|
|
||||||
p1 = std::numeric_limits<llama_pos>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there is no range then return early to avoid looping over the
|
|
||||||
if (p0 == p1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (recurrent) {
|
|
||||||
// for Mamba-like or RWKV models, only the pos needs to be shifted
|
|
||||||
if (0 <= seq_id && seq_id < (int64_t) size) {
|
|
||||||
const int32_t tail_id = cells[seq_id].tail;
|
|
||||||
if (tail_id >= 0) {
|
|
||||||
llama_kv_cell & cell = cells[tail_id];
|
|
||||||
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
|
||||||
cell.pos += delta;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
|
||||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
|
||||||
has_shift = true;
|
|
||||||
cells[i].pos += delta;
|
|
||||||
cells[i].delta += delta;
|
|
||||||
|
|
||||||
if (cells[i].pos < 0) {
|
|
||||||
if (!cells[i].is_empty()) {
|
|
||||||
used--;
|
|
||||||
}
|
|
||||||
cells[i].pos = -1;
|
|
||||||
cells[i].seq_id.clear();
|
|
||||||
if (new_head == size) {
|
|
||||||
new_head = i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
|
||||||
// Otherwise we just start the next search from the beginning.
|
|
||||||
head = new_head != size ? new_head : 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
|
||||||
if (d == 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (p0 < 0) {
|
|
||||||
p0 = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (p1 < 0) {
|
|
||||||
p1 = std::numeric_limits<llama_pos>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there is no range then return early to avoid looping over the cache.
|
|
||||||
if (p0 == p1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (recurrent) {
|
|
||||||
// for Mamba-like or RWKV models, only the pos needs to be changed
|
|
||||||
if (0 <= seq_id && seq_id < (int64_t) size) {
|
|
||||||
const int32_t tail_id = cells[seq_id].tail;
|
|
||||||
if (tail_id >= 0) {
|
|
||||||
llama_kv_cell & cell = cells[tail_id];
|
|
||||||
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
|
||||||
cell.pos /= d;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
|
||||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
|
||||||
has_shift = true;
|
|
||||||
|
|
||||||
{
|
|
||||||
llama_pos p_old = cells[i].pos;
|
|
||||||
cells[i].pos /= d;
|
|
||||||
cells[i].delta += cells[i].pos - p_old;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_pos seq_pos_max(llama_seq_id seq_id) {
|
|
||||||
llama_pos result = 0;
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
|
||||||
if (cells[i].has_seq_id(seq_id)) {
|
|
||||||
result = std::max(result, cells[i].pos);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
void defrag() {
|
|
||||||
if (!recurrent) {
|
|
||||||
do_defrag = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// a structure holds information about the slot found in llama_kv_cache_find_slot
|
|
||||||
struct llama_kv_cache_slot_info {
|
|
||||||
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
|
|
||||||
bool found = false; // the slot was found
|
|
||||||
|
|
||||||
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
|
|
||||||
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
|
|
||||||
|
|
||||||
operator bool() const { return found; }
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO: maybe not needed
|
|
||||||
uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);
|
|
||||||
|
|
||||||
bool llama_kv_cache_init(
|
|
||||||
struct llama_kv_cache & cache,
|
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const llama_cparams & cparams,
|
const llama_cparams & cparams,
|
||||||
ggml_type type_k,
|
ggml_type type_k,
|
||||||
@ -401,25 +81,38 @@ bool llama_kv_cache_init(
|
|||||||
uint32_t kv_size,
|
uint32_t kv_size,
|
||||||
bool offload);
|
bool offload);
|
||||||
|
|
||||||
// find an empty slot of size "n_tokens" in the cache
|
int32_t n_tokens() const;
|
||||||
// updates the cache head
|
|
||||||
// returns a structure holding information about the slot found
|
|
||||||
// Note: On success, it's important that cache.head points
|
|
||||||
// to the first cell of the slot.
|
|
||||||
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
|
||||||
struct llama_kv_cache & cache,
|
|
||||||
const struct llama_ubatch & batch);
|
|
||||||
|
|
||||||
// find how many cells are currently in use
|
size_t total_size() const;
|
||||||
uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);
|
|
||||||
|
|
||||||
//
|
// TODO: better data structures to reduce the cost of this operation
|
||||||
// kv cache view
|
llama_pos max_pos() const;
|
||||||
//
|
|
||||||
|
|
||||||
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
|
void clear();
|
||||||
|
|
||||||
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1);
|
||||||
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);
|
||||||
|
void seq_keep(llama_seq_id seq_id);
|
||||||
|
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
|
||||||
|
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d);
|
||||||
|
|
||||||
|
llama_pos seq_pos_max(llama_seq_id seq_id);
|
||||||
|
|
||||||
|
void defrag();
|
||||||
|
|
||||||
|
// find an empty slot of size "n_tokens" in the cache
|
||||||
|
// updates the cache head
|
||||||
|
// returns a structure holding information about the slot found
|
||||||
|
// Note: On success, it's important that cache.head points
|
||||||
|
// to the first cell of the slot.
|
||||||
|
llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
|
||||||
|
|
||||||
|
// TODO: maybe not needed
|
||||||
|
uint32_t get_padding(const llama_cparams & cparams) const;
|
||||||
|
|
||||||
|
// find how many cells are currently in use
|
||||||
|
uint32_t cell_max() const;
|
||||||
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
// kv cache restore
|
// kv cache restore
|
||||||
@ -472,3 +165,10 @@ struct llama_kv_slot_restorer {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// kv cache view
|
||||||
|
//
|
||||||
|
|
||||||
|
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
|
||||||
|
|
||||||
|
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
|
||||||
|
@ -8572,18 +8572,18 @@ static int llama_decode_impl(
|
|||||||
kv_self.head = 0;
|
kv_self.head = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
const auto slot_info = kv_self.find_slot(ubatch);
|
||||||
if (!slot) {
|
if (!slot_info) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
kv_slot_restorer.save(slot);
|
kv_slot_restorer.save(slot_info);
|
||||||
|
|
||||||
if (!kv_self.recurrent) {
|
if (!kv_self.recurrent) {
|
||||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||||
// after enough generations, the benefit from this heuristic disappears
|
// after enough generations, the benefit from this heuristic disappears
|
||||||
// if we start defragmenting the cache, the benefit from this will be more important
|
// if we start defragmenting the cache, the benefit from this will be more important
|
||||||
const uint32_t pad = llama_kv_cache_get_padding(cparams);
|
const uint32_t pad = kv_self.get_padding(cparams);
|
||||||
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
|
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
|
||||||
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -8969,7 +8969,7 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
|
|||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
|
|
||||||
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
|
const uint32_t n_kv = kv_self.cell_max();
|
||||||
const uint32_t n_used = kv_self.used;
|
const uint32_t n_used = kv_self.used;
|
||||||
|
|
||||||
assert(n_used <= n_kv);
|
assert(n_used <= n_kv);
|
||||||
@ -9540,7 +9540,7 @@ struct llama_context * llama_init_from_model(
|
|||||||
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
|
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
|
||||||
|
|
||||||
// this is necessary due to kv_self.n being padded later during inference
|
// this is necessary due to kv_self.n being padded later during inference
|
||||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
|
cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->kv_self.get_padding(cparams));
|
||||||
|
|
||||||
// with causal attention, the batch size is limited by the context size
|
// with causal attention, the batch size is limited by the context size
|
||||||
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
||||||
@ -9682,7 +9682,7 @@ struct llama_context * llama_init_from_model(
|
|||||||
|
|
||||||
llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
|
llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
|
||||||
|
|
||||||
if (!llama_kv_cache_init(ctx->kv_self, ctx->model, ctx->cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
if (!ctx->kv_self.init(ctx->model, ctx->cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
||||||
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
Loading…
Reference in New Issue
Block a user