mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 23:00:46 +01:00
llama : rethink recurrent state cell counts
* llama : begin work on support for variable GQA
This will also be useful for Jamba if we consider the Mamba layers to have 0 KV heads.
* llama : gracefully fail when not finding hybrid slot
This commit is contained in:
parent
3b57b55c6f
commit
7e13f19fb5
586
llama.cpp
586
llama.cpp
@ -1753,6 +1753,9 @@ struct llama_hparams {
|
|||||||
uint32_t n_expert_used = 0;
|
uint32_t n_expert_used = 0;
|
||||||
uint32_t n_vocab_type = 0; // for BERT-style token types
|
uint32_t n_vocab_type = 0; // for BERT-style token types
|
||||||
|
|
||||||
|
// TODO: find a more compact way to add more per-layer hyper-parameters
|
||||||
|
std::vector<int32_t> n_head_kv_vec;
|
||||||
|
|
||||||
float f_norm_eps;
|
float f_norm_eps;
|
||||||
float f_norm_rms_eps;
|
float f_norm_rms_eps;
|
||||||
|
|
||||||
@ -1793,6 +1796,8 @@ struct llama_hparams {
|
|||||||
if (this->n_expert != other.n_expert) return true;
|
if (this->n_expert != other.n_expert) return true;
|
||||||
if (this->n_expert_used != other.n_expert_used) return true;
|
if (this->n_expert_used != other.n_expert_used) return true;
|
||||||
|
|
||||||
|
if (this->n_head_kv_vec != other.n_head_kv_vec) return true;
|
||||||
|
|
||||||
if (this->rope_finetuned != other.rope_finetuned) return true;
|
if (this->rope_finetuned != other.rope_finetuned) return true;
|
||||||
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
|
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
|
||||||
|
|
||||||
@ -1812,29 +1817,46 @@ struct llama_hparams {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_gqa() const {
|
uint32_t n_head_kv_l(uint32_t layer) const {
|
||||||
|
if (layer < n_head_kv_vec.size()) {
|
||||||
|
int32_t n_hkv_l = n_head_kv_vec[layer];
|
||||||
|
// TODO: what should happen when it's negative?
|
||||||
|
GGML_ASSERT(n_hkv_l >= 0);
|
||||||
|
return n_hkv_l;
|
||||||
|
}
|
||||||
|
return n_head_kv;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t n_gqa(uint32_t layer = 0) const {
|
||||||
|
uint32_t n_head_kv = n_head_kv_l(layer);
|
||||||
if (n_head_kv == 0) {
|
if (n_head_kv == 0) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
return n_head/n_head_kv;
|
return n_head/n_head_kv;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_embd_k_gqa() const { // dimension of key embeddings across all k-v heads
|
uint32_t n_embd_k_gqa(uint32_t layer = 0) const { // dimension of key embeddings across all k-v heads
|
||||||
|
uint32_t n_head_kv = n_head_kv_l(layer);
|
||||||
return n_embd_head_k * n_head_kv;
|
return n_embd_head_k * n_head_kv;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
|
uint32_t n_embd_v_gqa(uint32_t layer = 0) const { // dimension of value embeddings across all k-v heads
|
||||||
|
uint32_t n_head_kv = n_head_kv_l(layer);
|
||||||
return n_embd_head_v * n_head_kv;
|
return n_embd_head_v * n_head_kv;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_embd_r() const { // dimension of the rolling state embeddings
|
uint32_t n_embd_r(uint32_t layer) const { // dimension of the rolling state embeddings
|
||||||
|
// TODO: support using an SSM in place of the MLP of a Transformer
|
||||||
|
if (n_head_kv_l(layer) != 0) { return 0; }
|
||||||
// corresponds to Mamba's conv_states size
|
// corresponds to Mamba's conv_states size
|
||||||
// TODO: maybe support other convolution strides than 1
|
// TODO: maybe support other convolution strides than 1
|
||||||
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
|
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
|
||||||
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
|
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_embd_s() const { // dimension of the recurrent state embeddings
|
uint32_t n_embd_s(uint32_t layer) const { // dimension of the recurrent state embeddings
|
||||||
|
// TODO: support using an SSM in place of the MLP of a Transformer
|
||||||
|
if (n_head_kv_l(layer) != 0) { return 0; }
|
||||||
// corresponds to Mamba's ssm_states size
|
// corresponds to Mamba's ssm_states size
|
||||||
return ssm_d_state * ssm_d_inner;
|
return ssm_d_state * ssm_d_inner;
|
||||||
}
|
}
|
||||||
@ -2078,10 +2100,12 @@ struct llama_rs_cache {
|
|||||||
// computed when finding a slot
|
// computed when finding a slot
|
||||||
uint32_t n = 0; // range of states used for the last slot
|
uint32_t n = 0; // range of states used for the last slot
|
||||||
|
|
||||||
// useful to know the minimum reserved cell count per seq_id
|
// only counts cells which are tails of all of their sequences.
|
||||||
// only counts sequences which have a non-shared tail
|
// useful to know the minimum reserved cell count per seq_id.
|
||||||
uint32_t n_seqs = 0;
|
uint32_t n_seqs = 0;
|
||||||
// cells part of multiple sequences AND which have at least one tail
|
// cells part of multiple sequences,
|
||||||
|
// but which are only the tail of some of them.
|
||||||
|
// useful to dismiss sequences used as a shared prompt
|
||||||
uint32_t n_shared_tail_cells = 0;
|
uint32_t n_shared_tail_cells = 0;
|
||||||
|
|
||||||
// with state models, a cell can hold the state for more than one past token
|
// with state models, a cell can hold the state for more than one past token
|
||||||
@ -2279,10 +2303,8 @@ struct llama_rs_cache {
|
|||||||
for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) {
|
for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) {
|
||||||
llama_rs_cell & rs_cell = cells[cell_id];
|
llama_rs_cell & rs_cell = cells[cell_id];
|
||||||
if (!rs_cell.seq_nodes.empty()) {
|
if (!rs_cell.seq_nodes.empty()) {
|
||||||
if (rs_cell.seq_nodes.size() == 1) {
|
if (rs_cell.seq_nodes.size() == rs_cell.tail_rc) {
|
||||||
if (rs_cell.tail_rc == 1) {
|
n_seqs_verif += 1;
|
||||||
n_seqs_verif += 1;
|
|
||||||
}
|
|
||||||
} else if (rs_cell.tail_rc > 0) {
|
} else if (rs_cell.tail_rc > 0) {
|
||||||
n_shared_tail_cells_verif += 1;
|
n_shared_tail_cells_verif += 1;
|
||||||
}
|
}
|
||||||
@ -2308,9 +2330,11 @@ struct llama_rs_cache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed.
|
// returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed.
|
||||||
|
// Why an iterator? Because it allows using std::vector<T>::erase.
|
||||||
std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) {
|
std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) {
|
||||||
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
|
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
|
||||||
// TODO: assert the iterator points inside the correct vector
|
// The iterator needs to point inside the correct vector
|
||||||
|
GGML_ASSERT(node_iter.base() >= rs_cell.seq_nodes.data() && node_iter.base() < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size());
|
||||||
if (node_iter != rs_cell.seq_nodes.end()) {
|
if (node_iter != rs_cell.seq_nodes.end()) {
|
||||||
// update the tree
|
// update the tree
|
||||||
llama_rs_seq_node node = *node_iter;
|
llama_rs_seq_node node = *node_iter;
|
||||||
@ -2325,12 +2349,20 @@ struct llama_rs_cache {
|
|||||||
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
|
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
|
||||||
prev_node->next_cell = node.next_cell;
|
prev_node->next_cell = node.next_cell;
|
||||||
if (node.is_tail()) {
|
if (node.is_tail()) {
|
||||||
|
// move the tail back to the previous cell
|
||||||
if (prev_cell.seq_nodes.size() > 1) {
|
if (prev_cell.seq_nodes.size() > 1) {
|
||||||
if (prev_cell.tail_rc == 0) {
|
if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) {
|
||||||
n_shared_tail_cells += 1;
|
if (prev_cell.tail_rc == 0) {
|
||||||
}
|
n_shared_tail_cells += 1;
|
||||||
if (rs_cell.seq_nodes.size() == 1) {
|
}
|
||||||
n_seqs -= 1;
|
|
||||||
|
// o oo oo
|
||||||
|
// |/ -> o/
|
||||||
|
// | |
|
||||||
|
// e.g. when removing the leaf with a single tail
|
||||||
|
if (rs_cell.tail_rc == 1 && prev_cell.tail_rc != prev_cell.seq_nodes.size()) {
|
||||||
|
n_seqs -= 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
prev_cell.tail_rc += 1;
|
prev_cell.tail_rc += 1;
|
||||||
@ -2341,17 +2373,22 @@ struct llama_rs_cache {
|
|||||||
if (node.is_tail()) {
|
if (node.is_tail()) {
|
||||||
seq.tail = rs_cell.prev;
|
seq.tail = rs_cell.prev;
|
||||||
if (rs_cell.tail_rc == 1) {
|
if (rs_cell.tail_rc == 1) {
|
||||||
if (rs_cell.seq_nodes.size() > 1) {
|
if (seq.tail < 0) {
|
||||||
// assuming the previous cell of a shared cell is also shared,
|
|
||||||
// this was a shared tail cell, but will no longer be a tail cell
|
|
||||||
n_shared_tail_cells -= 1;
|
|
||||||
} else if (seq.tail < 0) {
|
|
||||||
// no more tail, no more sequence
|
// no more tail, no more sequence
|
||||||
n_seqs -= 1;
|
if (rs_cell.seq_nodes.size() > 1) {
|
||||||
|
n_shared_tail_cells -= 1;
|
||||||
|
} else {
|
||||||
|
n_seqs -= 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
GGML_ASSERT(rs_cell.tail_rc > 0);
|
GGML_ASSERT(rs_cell.tail_rc > 0);
|
||||||
rs_cell.tail_rc -= 1;
|
rs_cell.tail_rc -= 1;
|
||||||
|
} else if (rs_cell.tail_rc == rs_cell.seq_nodes.size() - 1) {
|
||||||
|
// will fully become a tail cell
|
||||||
|
if (rs_cell.tail_rc > 0) {
|
||||||
|
n_seqs += 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (node_iter == rs_cell.seq_nodes.begin()) {
|
if (node_iter == rs_cell.seq_nodes.begin()) {
|
||||||
// this seq_id was the first in the list
|
// this seq_id was the first in the list
|
||||||
@ -2363,14 +2400,6 @@ struct llama_rs_cache {
|
|||||||
if ((uint32_t) next_node->seq_id < seq_tails.size()) {
|
if ((uint32_t) next_node->seq_id < seq_tails.size()) {
|
||||||
auto & next_seq = seq_tails[next_node->seq_id];
|
auto & next_seq = seq_tails[next_node->seq_id];
|
||||||
next_seq.n_cells += 1;
|
next_seq.n_cells += 1;
|
||||||
// only the tail ref count from the other seq_ids are left in tail_rc
|
|
||||||
if (rs_cell.tail_rc > 0) {
|
|
||||||
// will become a non-shared cell
|
|
||||||
if (rs_cell.seq_nodes.size() == 2) {
|
|
||||||
n_shared_tail_cells -= 1;
|
|
||||||
n_seqs += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false && "invalid seq_id");
|
GGML_ASSERT(false && "invalid seq_id");
|
||||||
}
|
}
|
||||||
@ -2433,43 +2462,41 @@ struct llama_rs_cache {
|
|||||||
rs_cell.pos = prev_cell.pos + 1;
|
rs_cell.pos = prev_cell.pos + 1;
|
||||||
rs_cell.src = prev_cell.src;
|
rs_cell.src = prev_cell.src;
|
||||||
}
|
}
|
||||||
prev_cell.tail_rc -= 1;
|
|
||||||
prev_node->next_cell = i_cell;
|
prev_node->next_cell = i_cell;
|
||||||
rs_cell.prev = prev;
|
rs_cell.prev = prev;
|
||||||
if (seq.tail == prev) {
|
if (seq.tail == prev) {
|
||||||
// What to do when the tail moves...
|
// What to do when the tail moves...
|
||||||
// from unique to shared (n_seqs--)
|
// (Legend: tail: O, one or more non-tails: o, one or more tails O+, empty: _)
|
||||||
// if the new cell has one seq_id or has no tails (n_shared_tail_cells++)
|
// O -> oO (n_seqs--, n_shared_tail_cells++)
|
||||||
// if the new cell has one seq_id and a tail (n_seqs-- (yes, another time))
|
// O -> O (seq.n_cells++)
|
||||||
// from unique to unique (seq.n_cells++)
|
// OO+ -> oO (n_seqs--, n_shared_tail_cells += 2)
|
||||||
// from empty to unique (seq.n_cells++, n_seqs++)
|
// OO+ -> O+ (n_shared_tail_cells++ (the previous cell becomes oO+))
|
||||||
// from empty to shared
|
// _ -> oO (n_shared_tail_cells++)
|
||||||
// if the new cell only has one seq_id or has no tail (n_shared_tail_cells++)
|
// _ -> O (seq.n_cells++, n_seqs++)
|
||||||
// if the new cell only has one seq_id and has one tail (n_seqs--)
|
// Oo -> O (seq.n_cells++, n_seqs++, n_shared_tail_cells--)
|
||||||
// from shared to shared
|
// Oo -> OO+ (n_shared_tail_cells--)
|
||||||
// if the last cell has no tails (n_shared_tail_cells--)
|
// OOo -> O (seq.n_cells++, n_seqs++)
|
||||||
// if the new cell has no tails or has one seq_id (n_shared_tail_cells++)
|
if (prev_cell.seq_nodes.size() == prev_cell.tail_rc) {
|
||||||
// if the new cell only has one seq_id and has one tail (n_seqs--)
|
// from fully tail
|
||||||
// from shared to unique (seq.n_cells++)
|
if (prev_cell.tail_rc > 1) {
|
||||||
// if this seq_id was not the first of the last cell (n_seqs++)
|
// the previous tail becomes shared with a non-tail
|
||||||
// if the last cell has no tails (n_shared_tail_cells--)
|
n_shared_tail_cells += 1;
|
||||||
if (prev_cell.seq_nodes.size() > 1) {
|
|
||||||
// from shared
|
|
||||||
if (rs_cell.is_empty()) {
|
|
||||||
// to unique
|
|
||||||
if (prev_cell.seq_nodes[0].seq_id != id) {
|
|
||||||
n_seqs += 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// the previous cell is no longer a shared tail
|
if (!rs_cell.is_empty() && rs_cell.tail_rc == 0) {
|
||||||
if (prev_cell.tail_rc == 0) {
|
// the new tail cell was previously a fully non-tail cell
|
||||||
|
n_shared_tail_cells += 1;
|
||||||
|
n_seqs -= 1;
|
||||||
|
}
|
||||||
|
} else if (rs_cell.is_empty()) {
|
||||||
|
// from shared to unique
|
||||||
|
n_seqs += 1;
|
||||||
|
if (prev_cell.tail_rc == 1) {
|
||||||
|
// it was the last tail of the previous cell
|
||||||
n_shared_tail_cells -= 1;
|
n_shared_tail_cells -= 1;
|
||||||
}
|
}
|
||||||
} else if (!rs_cell.is_empty()) {
|
|
||||||
// from unique to shared
|
|
||||||
n_seqs -= 1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
prev_cell.tail_rc -= 1;
|
||||||
}
|
}
|
||||||
if (rs_cell.is_empty()) {
|
if (rs_cell.is_empty()) {
|
||||||
// to unique
|
// to unique
|
||||||
@ -2482,15 +2509,10 @@ struct llama_rs_cache {
|
|||||||
rs_cell.src = -1;
|
rs_cell.src = -1;
|
||||||
}
|
}
|
||||||
used += 1;
|
used += 1;
|
||||||
} else {
|
} else if (rs_cell.tail_rc == 0) {
|
||||||
// to shared
|
// to shared
|
||||||
if (rs_cell.seq_nodes.size() == 1) {
|
if (seq.tail < 0) {
|
||||||
// a lone tail becomes a shared cell
|
// from empty to shared
|
||||||
if (rs_cell.tail_rc > 0) {
|
|
||||||
n_seqs -= 1;
|
|
||||||
}
|
|
||||||
n_shared_tail_cells += 1;
|
|
||||||
} else if (rs_cell.tail_rc == 0) {
|
|
||||||
n_shared_tail_cells += 1;
|
n_shared_tail_cells += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2910,26 +2932,18 @@ static bool llama_cache_init(
|
|||||||
const llama_context * ctx,
|
const llama_context * ctx,
|
||||||
ggml_type type_k,
|
ggml_type type_k,
|
||||||
ggml_type type_v,
|
ggml_type type_v,
|
||||||
uint32_t n_ctx,
|
|
||||||
uint32_t n_seq_max,
|
|
||||||
bool offload) {
|
bool offload) {
|
||||||
const llama_model & model = ctx->model;
|
const llama_model & model = ctx->model;
|
||||||
const llama_cparams & cparams = ctx->cparams;
|
const llama_cparams & cparams = ctx->cparams;
|
||||||
|
|
||||||
const struct llama_hparams & hparams = model.hparams;
|
const struct llama_hparams & hparams = model.hparams;
|
||||||
|
|
||||||
// TODO: per layer n_embd_*
|
const bool has_kv = hparams.n_head_kv != 0 && hparams.causal_attn;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const bool has_r = hparams.ssm_d_conv != 0 && hparams.ssm_d_inner != 0;
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const bool has_s = hparams.ssm_d_state != 0 && hparams.ssm_d_inner != 0;
|
||||||
const uint32_t n_embd_r = hparams.n_embd_r();
|
|
||||||
const uint32_t n_embd_s = hparams.n_embd_s();
|
|
||||||
const bool has_kv = hparams.n_head != 0 && hparams.causal_attn;
|
|
||||||
const bool has_r = n_embd_r != 0;
|
|
||||||
const bool has_s = n_embd_s != 0;
|
|
||||||
const bool has_rs = has_r || has_s;
|
const bool has_rs = has_r || has_s;
|
||||||
const uint32_t kv_size = has_kv ? n_ctx : 0;
|
const uint32_t kv_size = has_kv ? cparams.n_ctx : 0;
|
||||||
const uint32_t rs_size = has_rs ? n_seq_max : 0;
|
const uint32_t rs_size = has_rs ? cparams.n_seq_max : 0;
|
||||||
// TODO: per cache type layer count
|
|
||||||
const int64_t n_layer = hparams.n_layer;
|
const int64_t n_layer = hparams.n_layer;
|
||||||
|
|
||||||
cache.kv.size = kv_size;
|
cache.kv.size = kv_size;
|
||||||
@ -2967,6 +2981,7 @@ static bool llama_cache_init(
|
|||||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||||
for (auto & it : buft_layer_count) {
|
for (auto & it : buft_layer_count) {
|
||||||
int n_layers = it.second;
|
int n_layers = it.second;
|
||||||
|
// TODO: for mixed architectures, avoid allocating empty recurrent state or kv cache tensors
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
/*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(),
|
/*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(),
|
||||||
/*.mem_buffer =*/ NULL,
|
/*.mem_buffer =*/ NULL,
|
||||||
@ -2995,20 +3010,20 @@ static bool llama_cache_init(
|
|||||||
for (int i = 0; i < (int) n_layer; i++) {
|
for (int i = 0; i < (int) n_layer; i++) {
|
||||||
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
||||||
if (has_kv) {
|
if (has_kv) {
|
||||||
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
|
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size);
|
||||||
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size);
|
||||||
ggml_format_name(k, "cache_k_l%d", i);
|
ggml_format_name(k, "cache_k_l%d", i);
|
||||||
ggml_format_name(v, "cache_v_l%d", i);
|
ggml_format_name(v, "cache_v_l%d", i);
|
||||||
cache.kv.k_l.push_back(k);
|
cache.kv.k_l.push_back(k);
|
||||||
cache.kv.v_l.push_back(v);
|
cache.kv.v_l.push_back(v);
|
||||||
}
|
}
|
||||||
if (has_r) {
|
if (has_r) {
|
||||||
ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_r*rs_size);
|
ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size);
|
||||||
ggml_format_name(r, "cache_r_l%d", i);
|
ggml_format_name(r, "cache_r_l%d", i);
|
||||||
cache.rs.r_l.push_back(r);
|
cache.rs.r_l.push_back(r);
|
||||||
}
|
}
|
||||||
if (has_s) {
|
if (has_s) {
|
||||||
ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_s*rs_size);
|
ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size);
|
||||||
ggml_format_name(s, "cache_s_l%d", i);
|
ggml_format_name(s, "cache_s_l%d", i);
|
||||||
cache.rs.s_l.push_back(s);
|
cache.rs.s_l.push_back(s);
|
||||||
}
|
}
|
||||||
@ -3024,7 +3039,7 @@ static bool llama_cache_init(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ggml_backend_buffer_clear(buf, 0);
|
ggml_backend_buffer_clear(buf, 0);
|
||||||
LLAMA_LOG_INFO("%s: %10s ctx buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
LLAMA_LOG_INFO("%s: %10s cache buf size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
||||||
cache.bufs.push_back(buf);
|
cache.bufs.push_back(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3042,177 +3057,21 @@ static bool llama_cache_find_slot(
|
|||||||
const uint32_t rs_size = cache.rs.size;
|
const uint32_t rs_size = cache.rs.size;
|
||||||
const uint32_t n_tokens = batch.n_tokens;
|
const uint32_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
// FIXME: on failure, leave all caches in a consistent state.
|
// only check first, to allow failing gracefully
|
||||||
|
|
||||||
if (rs_size > 0) {
|
if (rs_size > 0) {
|
||||||
// For recurrent state architectures (like Mamba),
|
// everything should fit if all seq_ids are smaller than the max
|
||||||
// each cache cell can store the state for a whole sequence.
|
|
||||||
// TODO: find a way to always make the rs slot contiguous
|
|
||||||
|
|
||||||
llama_seq_id min_seq = cache.rs.size - 1;
|
|
||||||
llama_seq_id max_seq = 0;
|
|
||||||
uint32_t min_cell = cache.rs.size - 1;
|
|
||||||
uint32_t max_cell = 0;
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||||
int32_t target_cell = -1; // ensure all the sequences of a token get the same cell
|
int32_t n_seq_id = batch.n_seq_id[i];
|
||||||
int32_t n_seq_ids = batch.n_seq_id[i];
|
for (int32_t j = 0; j < n_seq_id; ++j) {
|
||||||
for (int32_t j = 0; j < n_seq_ids; ++j) {
|
|
||||||
llama_seq_id seq_id = batch.seq_id[i][j];
|
llama_seq_id seq_id = batch.seq_id[i][j];
|
||||||
bool need_new_cell = false;
|
|
||||||
// Everything should fit assuming the biggest seq_id < rs_size
|
|
||||||
if ((uint32_t) seq_id < rs_size) {
|
|
||||||
llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id];
|
|
||||||
if (seq_id > max_seq) { max_seq = seq_id; }
|
|
||||||
if (seq_id < min_seq) { min_seq = seq_id; }
|
|
||||||
|
|
||||||
if (!seq.in_ubatch && target_cell >= 0) {
|
if (seq_id < 0 || (uint32_t) seq_id >= rs_size) {
|
||||||
// never saw this seq_id before,
|
|
||||||
// but there's already a cell reserved for this token, use it
|
|
||||||
cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id);
|
|
||||||
} else if (seq.tail < 0) {
|
|
||||||
need_new_cell = true;
|
|
||||||
} else {
|
|
||||||
llama_rs_cell & tail = cache.rs.cells[seq.tail];
|
|
||||||
if (seq.in_ubatch) {
|
|
||||||
// this seq_id was already seen before in the batch
|
|
||||||
// assuming the tail cell already "has" this seq_id
|
|
||||||
tail.pos += 1;
|
|
||||||
target_cell = seq.tail;
|
|
||||||
} else {
|
|
||||||
// first time this sequence is seen,
|
|
||||||
// there's no reserved cell yet;
|
|
||||||
// if it's not the first sequence of the token, how could it even get here?
|
|
||||||
GGML_ASSERT(j == 0);
|
|
||||||
|
|
||||||
bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids;
|
|
||||||
if (has_same_seqs) {
|
|
||||||
// the tail cell of a seq_id is assumed to already be part of the seq_id,
|
|
||||||
// hence the skip of the first seq_id
|
|
||||||
for (int32_t k = 1; k < n_seq_ids; ++k) {
|
|
||||||
if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) {
|
|
||||||
has_same_seqs = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: make the checkpoint interval configurable
|
|
||||||
if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) {
|
|
||||||
// a checkpoint should be saved
|
|
||||||
need_new_cell = true;
|
|
||||||
} else {
|
|
||||||
// re-use last tail
|
|
||||||
tail.pos += 1;
|
|
||||||
target_cell = seq.tail;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (need_new_cell && target_cell < 0) {
|
|
||||||
const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq);
|
|
||||||
|
|
||||||
uint32_t cell_id = cache.rs.size;
|
|
||||||
bool looped_once = false;
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
if (cache.rs.head >= cache.rs.size) {
|
|
||||||
cache.rs.head = 0;
|
|
||||||
if (looped_once) {
|
|
||||||
// avoid infinite loop
|
|
||||||
// NOTE: this should not happen, but gracefully fail anyway
|
|
||||||
LLAMA_LOG_ERROR("%s: recurrent state cache seems full, but should not. This is a bug.\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
looped_once = true;
|
|
||||||
}
|
|
||||||
cell_id = cache.rs.head;
|
|
||||||
llama_rs_cell & candidate = cache.rs.cells[cell_id];
|
|
||||||
if (candidate.is_empty()) { break; }
|
|
||||||
if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) {
|
|
||||||
// the candidate is the old tail
|
|
||||||
if (candidate.seq_nodes.size() > 1) {
|
|
||||||
// prune out the other seq_ids, because they diverge
|
|
||||||
// TODO(maybe): handle this in insert_seq_tail_to_cell_id
|
|
||||||
// (hopefully doesn't happen too often)
|
|
||||||
for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) {
|
|
||||||
if (node_iter->seq_id == seq_id) {
|
|
||||||
node_iter = std::next(node_iter);
|
|
||||||
} else {
|
|
||||||
node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// re-use the tail cell to avoid not finding anything
|
|
||||||
candidate.pos += 1;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (candidate.tail_rc > 0) {
|
|
||||||
// skip tails of other sequences
|
|
||||||
cache.rs.head += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (candidate.seq_nodes.size() > 1) {
|
|
||||||
// shared prompts are not usually backtracked, so they can be pruned
|
|
||||||
cache.rs.clear_cell(candidate);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// prune too-long sequences
|
|
||||||
llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id;
|
|
||||||
if (seq_id_to_prune == seq_id) {
|
|
||||||
// TODO: selectively skip some cells to keep older states
|
|
||||||
cache.rs.clear_cell(candidate);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size());
|
|
||||||
auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune];
|
|
||||||
if (seq_to_prune.n_cells > min_cells_per_seq) {
|
|
||||||
cache.rs.clear_cell(candidate);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
cache.rs.head += 1;
|
|
||||||
}
|
|
||||||
if (cell_id < cache.rs.size) {
|
|
||||||
cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id);
|
|
||||||
target_cell = cell_id;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (seq.tail >= 0) {
|
|
||||||
if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; }
|
|
||||||
if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; }
|
|
||||||
seq.in_ubatch = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assuming the tokens are in-order
|
|
||||||
if (batch.pos[i] != cache.rs.cells[seq.tail].pos) {
|
|
||||||
// What should happen when the pos backtracks or skips a value?
|
|
||||||
// Clearing the state mid-batch would require special-casing which isn't done.
|
|
||||||
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
|
|
||||||
__func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// too big seq_id
|
// too big seq_id
|
||||||
// TODO: would it be possible to resize the rs cache size instead?
|
// TODO: would it be possible to resize the rs cache size instead?
|
||||||
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size);
|
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cache.rs.head = target_cell + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (llama_seq_id i = min_seq; i <= max_seq; ++i) {
|
|
||||||
// make sure it's cleared for next time
|
|
||||||
cache.rs.seq_tails[i].in_ubatch = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// allow getting the range of used cells, from head to head + n
|
|
||||||
cache.rs.head = min_cell;
|
|
||||||
cache.rs.n = max_cell - min_cell + 1;
|
|
||||||
|
|
||||||
// sanity check
|
|
||||||
if (max_seq < min_seq || max_cell < min_cell) {
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3257,7 +3116,174 @@ static bool llama_cache_find_slot(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// now modification can be done, and should NOT fail
|
||||||
|
|
||||||
|
if (rs_size > 0) {
|
||||||
|
// For recurrent state architectures (like Mamba),
|
||||||
|
// each cache cell can store the state for a whole sequence.
|
||||||
|
// TODO: find a way to always make the rs slot contiguous
|
||||||
|
|
||||||
|
llama_seq_id min_seq = cache.rs.size - 1;
|
||||||
|
llama_seq_id max_seq = 0;
|
||||||
|
uint32_t min_cell = cache.rs.size - 1;
|
||||||
|
uint32_t max_cell = 0;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||||
|
int32_t target_cell = -1; // ensure all the sequences of a token get the same cell
|
||||||
|
int32_t n_seq_ids = batch.n_seq_id[i];
|
||||||
|
for (int32_t j = 0; j < n_seq_ids; ++j) {
|
||||||
|
llama_seq_id seq_id = batch.seq_id[i][j];
|
||||||
|
bool need_new_cell = false;
|
||||||
|
// Everything should fit assuming the biggest seq_id < rs_size
|
||||||
|
GGML_ASSERT((uint32_t) seq_id < rs_size);
|
||||||
|
llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id];
|
||||||
|
if (seq_id > max_seq) { max_seq = seq_id; }
|
||||||
|
if (seq_id < min_seq) { min_seq = seq_id; }
|
||||||
|
|
||||||
|
if (!seq.in_ubatch && target_cell >= 0) {
|
||||||
|
// never saw this seq_id before,
|
||||||
|
// but there's already a cell reserved for this token, use it
|
||||||
|
cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id);
|
||||||
|
} else if (seq.tail < 0) {
|
||||||
|
// this seq_id has no tail (and is empty)
|
||||||
|
need_new_cell = true;
|
||||||
|
} else {
|
||||||
|
llama_rs_cell & tail = cache.rs.cells[seq.tail];
|
||||||
|
if (seq.in_ubatch) {
|
||||||
|
// this seq_id was already seen before in the batch
|
||||||
|
// assuming the tail cell already "has" this seq_id
|
||||||
|
tail.pos += 1;
|
||||||
|
target_cell = seq.tail;
|
||||||
|
} else {
|
||||||
|
// first time this sequence is seen,
|
||||||
|
// there's no reserved cell yet;
|
||||||
|
// if it's not the first sequence of the token, how could it even get here?
|
||||||
|
GGML_ASSERT(j == 0);
|
||||||
|
|
||||||
|
bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids;
|
||||||
|
if (has_same_seqs) {
|
||||||
|
// the tail cell of a seq_id is assumed to already be part of the seq_id,
|
||||||
|
// hence the skip of the first seq_id
|
||||||
|
for (int32_t k = 1; k < n_seq_ids; ++k) {
|
||||||
|
if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) {
|
||||||
|
has_same_seqs = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: make the checkpoint interval configurable
|
||||||
|
if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) {
|
||||||
|
// a checkpoint should be saved
|
||||||
|
need_new_cell = true;
|
||||||
|
} else {
|
||||||
|
// re-use last tail
|
||||||
|
tail.pos += 1;
|
||||||
|
target_cell = seq.tail;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reserve a cell for this seq_id
|
||||||
|
if (need_new_cell && target_cell < 0) {
|
||||||
|
const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq);
|
||||||
|
|
||||||
|
uint32_t cell_id = cache.rs.size;
|
||||||
|
bool looped_once = false;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (cache.rs.head >= cache.rs.size) {
|
||||||
|
cache.rs.head = 0;
|
||||||
|
// avoid infinite loop
|
||||||
|
// NOTE: this should not fail; if it does, it's a bug.
|
||||||
|
GGML_ASSERT(!looped_once && "recurrent state cache seems full, but should not.");
|
||||||
|
looped_once = true;
|
||||||
|
}
|
||||||
|
cell_id = cache.rs.head;
|
||||||
|
llama_rs_cell & candidate = cache.rs.cells[cell_id];
|
||||||
|
if (candidate.is_empty()) { break; }
|
||||||
|
if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) {
|
||||||
|
// the candidate is the old tail
|
||||||
|
if (candidate.seq_nodes.size() > 1) {
|
||||||
|
// prune out the other seq_ids, because they diverge
|
||||||
|
// TODO(maybe): handle this in insert_seq_tail_to_cell_id
|
||||||
|
// (hopefully doesn't happen too often)
|
||||||
|
for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) {
|
||||||
|
if (node_iter->seq_id == seq_id) {
|
||||||
|
node_iter = std::next(node_iter);
|
||||||
|
} else {
|
||||||
|
node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// re-use the tail cell to avoid not finding anything
|
||||||
|
candidate.pos += 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (candidate.tail_rc > 0) {
|
||||||
|
// skip tails of other sequences
|
||||||
|
cache.rs.head += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (candidate.seq_nodes.size() > 1) {
|
||||||
|
// shared prompts are not usually backtracked, so they can be pruned
|
||||||
|
cache.rs.clear_cell(candidate);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// prune too-long sequences
|
||||||
|
llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id;
|
||||||
|
if (seq_id_to_prune == seq_id) {
|
||||||
|
// TODO: selectively skip some cells to keep older states
|
||||||
|
cache.rs.clear_cell(candidate);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size());
|
||||||
|
auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune];
|
||||||
|
if (seq_to_prune.n_cells > min_cells_per_seq) {
|
||||||
|
cache.rs.clear_cell(candidate);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
cache.rs.head += 1;
|
||||||
|
}
|
||||||
|
if (cell_id < cache.rs.size) {
|
||||||
|
cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id);
|
||||||
|
target_cell = cell_id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (seq.tail >= 0) {
|
||||||
|
if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; }
|
||||||
|
if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; }
|
||||||
|
seq.in_ubatch = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assuming the tokens are in-order
|
||||||
|
if (batch.pos[i] != cache.rs.cells[seq.tail].pos) {
|
||||||
|
// What should happen when the pos backtracks or skips a value?
|
||||||
|
// Clearing the state mid-batch would require special-casing which isn't done.
|
||||||
|
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
|
||||||
|
__func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cache.rs.head = target_cell + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (llama_seq_id i = min_seq; i <= max_seq; ++i) {
|
||||||
|
// make sure it's cleared for next time
|
||||||
|
cache.rs.seq_tails[i].in_ubatch = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow getting the range of used cells, from head to head + n
|
||||||
|
cache.rs.head = min_cell;
|
||||||
|
cache.rs.n = max_cell - min_cell + 1;
|
||||||
|
|
||||||
|
// sanity check
|
||||||
|
GGML_ASSERT(min_seq <= max_seq && min_cell <= max_cell);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (kv_size > 0) {
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
cache.kv.cells[cache.kv.head + i].pos = batch.pos[i];
|
cache.kv.cells[cache.kv.head + i].pos = batch.pos[i];
|
||||||
|
|
||||||
@ -4194,9 +4220,9 @@ struct llama_model_loader {
|
|||||||
bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
|
bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
|
||||||
const int kid = gguf_find_key(meta, key.c_str());
|
const int kid = gguf_find_key(meta, key.c_str());
|
||||||
|
|
||||||
if (kid < 0) {
|
if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
|
||||||
if (required) {
|
if (required) {
|
||||||
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
|
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -4204,16 +4230,17 @@ struct llama_model_loader {
|
|||||||
struct GGUFMeta::ArrayInfo arr_info =
|
struct GGUFMeta::ArrayInfo arr_info =
|
||||||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
|
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
|
||||||
|
|
||||||
if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) {
|
// TODO: allow ANY lossless cast
|
||||||
throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str()));
|
// GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T));
|
||||||
|
switch (arr_info.gt) {
|
||||||
|
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||||
|
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value)); break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T));
|
result.reserve(arr_info.length);
|
||||||
GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same<T, float>::value));
|
result.assign((const T *)arr_info.data, (const T *)arr_info.data + arr_info.length);
|
||||||
GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same<T, int>::value));
|
|
||||||
|
|
||||||
result.resize(arr_info.length);
|
|
||||||
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -4750,7 +4777,12 @@ static void llm_load_hparams(
|
|||||||
|
|
||||||
// n_head_kv is optional, default to n_head
|
// n_head_kv is optional, default to n_head
|
||||||
hparams.n_head_kv = hparams.n_head;
|
hparams.n_head_kv = hparams.n_head;
|
||||||
ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false);
|
|
||||||
|
// per-layer n_head_kv
|
||||||
|
if (!ml.get_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_vec, false)) {
|
||||||
|
// global/fallback n_head_kv
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false);
|
||||||
|
}
|
||||||
|
|
||||||
bool rope_finetuned = false;
|
bool rope_finetuned = false;
|
||||||
ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
|
ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
|
||||||
@ -6704,10 +6736,7 @@ static bool llm_load_tensors(
|
|||||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||||
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
|
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
|
||||||
|
|
||||||
const int64_t n_ff = hparams.n_ff;
|
|
||||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
|
||||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||||
@ -7198,8 +7227,8 @@ static void llm_build_kv_store(
|
|||||||
int64_t il) {
|
int64_t il) {
|
||||||
const int64_t n_ctx = cparams.n_ctx;
|
const int64_t n_ctx = cparams.n_ctx;
|
||||||
|
|
||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||||
|
|
||||||
GGML_ASSERT(kv.size == n_ctx);
|
GGML_ASSERT(kv.size == n_ctx);
|
||||||
|
|
||||||
@ -7465,9 +7494,9 @@ static struct ggml_tensor * llm_build_kqv(
|
|||||||
int il) {
|
int il) {
|
||||||
const int64_t n_ctx = cparams.n_ctx;
|
const int64_t n_ctx = cparams.n_ctx;
|
||||||
const int64_t n_head = hparams.n_head;
|
const int64_t n_head = hparams.n_head;
|
||||||
const int64_t n_head_kv = hparams.n_head_kv;
|
const int64_t n_head_kv = hparams.n_head_kv_l(il);
|
||||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||||
const int64_t n_embd_head_v = hparams.n_embd_head_v;
|
const int64_t n_embd_head_v = hparams.n_embd_head_v;
|
||||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
|
|
||||||
@ -7619,9 +7648,7 @@ struct llm_build_context {
|
|||||||
const int64_t n_head;
|
const int64_t n_head;
|
||||||
const int64_t n_head_kv;
|
const int64_t n_head_kv;
|
||||||
const int64_t n_embd_head_k;
|
const int64_t n_embd_head_k;
|
||||||
const int64_t n_embd_k_gqa;
|
|
||||||
const int64_t n_embd_head_v;
|
const int64_t n_embd_head_v;
|
||||||
const int64_t n_embd_v_gqa;
|
|
||||||
const int64_t n_expert;
|
const int64_t n_expert;
|
||||||
const int64_t n_expert_used;
|
const int64_t n_expert_used;
|
||||||
|
|
||||||
@ -7673,9 +7700,7 @@ struct llm_build_context {
|
|||||||
n_head (hparams.n_head),
|
n_head (hparams.n_head),
|
||||||
n_head_kv (hparams.n_head_kv),
|
n_head_kv (hparams.n_head_kv),
|
||||||
n_embd_head_k (hparams.n_embd_head_k),
|
n_embd_head_k (hparams.n_embd_head_k),
|
||||||
n_embd_k_gqa (hparams.n_embd_k_gqa()),
|
|
||||||
n_embd_head_v (hparams.n_embd_head_v),
|
n_embd_head_v (hparams.n_embd_head_v),
|
||||||
n_embd_v_gqa (hparams.n_embd_v_gqa()),
|
|
||||||
n_expert (hparams.n_expert),
|
n_expert (hparams.n_expert),
|
||||||
n_expert_used (hparams.n_expert_used),
|
n_expert_used (hparams.n_expert_used),
|
||||||
freq_base (cparams.rope_freq_base),
|
freq_base (cparams.rope_freq_base),
|
||||||
@ -7746,9 +7771,9 @@ struct llm_build_context {
|
|||||||
// we rotate only the first n_rot dimensions
|
// we rotate only the first n_rot dimensions
|
||||||
ggml_rope_ext_inplace(ctx0,
|
ggml_rope_ext_inplace(ctx0,
|
||||||
ggml_view_3d(ctx0, kv_self.k_l[il],
|
ggml_view_3d(ctx0, kv_self.k_l[il],
|
||||||
n_embd_head_k, n_head_kv, n_ctx,
|
n_embd_head_k, hparams.n_head_kv_l(il), n_ctx,
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
ggml_row_size(kv_self.k_l[il]->type, hparams.n_embd_k_gqa(il)),
|
||||||
0),
|
0),
|
||||||
lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
@ -7777,6 +7802,9 @@ struct llm_build_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||||
|
int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||||
|
|
||||||
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
|
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
|
||||||
n_embd_k_gqa, nm,
|
n_embd_k_gqa, nm,
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||||
@ -11014,8 +11042,8 @@ struct llm_build_context {
|
|||||||
struct ggml_tensor * state_seq = build_inp_s_seq();
|
struct ggml_tensor * state_seq = build_inp_s_seq();
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(), rs_self.size);
|
struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(il), rs_self.size);
|
||||||
struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(), rs_self.size);
|
struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(il), rs_self.size);
|
||||||
|
|
||||||
// copy states
|
// copy states
|
||||||
{
|
{
|
||||||
@ -16452,7 +16480,7 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
}
|
}
|
||||||
ctx->backends.push_back(ctx->backend_cpu);
|
ctx->backends.push_back(ctx->backend_cpu);
|
||||||
|
|
||||||
if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.n_ctx, cparams.n_seq_max, cparams.offload_kqv)) {
|
if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) {
|
||||||
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -17282,7 +17310,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
|
|||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
|
|
||||||
// NOTE: kv_size and kv_buf_size are mostly used for sanity checks
|
// NOTE: kv_size and kv_buf_size are mostly used for sanity checks
|
||||||
@ -17434,7 +17462,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
|||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
|
|
||||||
size_t kv_buf_size;
|
size_t kv_buf_size;
|
||||||
@ -17627,7 +17655,7 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
|
|||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
|
|
||||||
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
||||||
@ -17713,7 +17741,7 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
|||||||
|
|
||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
|
|
||||||
// Write the layer count
|
// Write the layer count
|
||||||
@ -17843,7 +17871,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
|||||||
// Sanity check model compatibility
|
// Sanity check model compatibility
|
||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
if (n_layer != n_layer_ref) {
|
if (n_layer != n_layer_ref) {
|
||||||
LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
|
LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
|
||||||
|
Loading…
Reference in New Issue
Block a user