Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-03 17:51:09 +01:00)
llama : use std::array for per-layer hparams
commit 29ab5a0ed1
parent e3e33c0cbc

src/llama.cpp: 309 changed lines
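The change replaces the std::vector<int32_t> per-layer fields of llama_hparams (n_head_vec, n_head_kv_vec, n_ff_vec) with fixed-size std::array members sized LLAMA_MAX_LAYERS, and turns the old *_l() getters into n_head(il), n_head_kv(il) and n_ff(il) with a defaulted layer index. A minimal, self-contained sketch of that accessor pattern follows; it is not the real llama_hparams struct, and the member subset, example values and main() are illustrative only:

// Minimal sketch of the pattern this commit introduces: fixed-size per-layer
// arrays plus getters that take a layer index. LLAMA_MAX_LAYERS and the member
// names follow the diff; everything else here is made up for illustration.
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

#define LLAMA_MAX_LAYERS 256

struct hparams_sketch {
    uint32_t n_layer = 0;

    // per-layer values stored in trivially copyable fixed-size arrays
    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr{};
    std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr{};

    // getter with a defaulted layer index: n_head() reads layer 0,
    // n_head(il) reads layer il; out-of-range access is a hard error
    uint32_t n_head(uint32_t il = 0) const {
        assert(il < n_layer);
        return n_head_arr[il];
    }

    uint32_t n_ff(uint32_t il = 0) const {
        assert(il < n_layer);
        return n_ff_arr[il];
    }
};

int main() {
    hparams_sketch hp;
    hp.n_layer    = 3;
    hp.n_head_arr = {8, 8, 4};          // e.g. a model with fewer heads in its last layer
    hp.n_ff_arr   = {1024, 1024, 512};

    for (uint32_t il = 0; il < hp.n_layer; ++il) {
        std::printf("layer %u: n_head = %u, n_ff = %u\n", il, hp.n_head(il), hp.n_ff(il));
    }
    return 0;
}

Keeping the per-layer values in std::array rather than std::vector also leaves the struct trivially copyable, which the static_assert further down in the diff enforces.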
@@ -105,6 +105,7 @@
#endif

#define LLAMA_MAX_NODES 8192
#define LLAMA_MAX_LAYERS 256
#define LLAMA_MAX_EXPERTS 160

//
@@ -2101,21 +2102,17 @@ struct llama_hparams {
uint32_t n_vocab;
uint32_t n_ctx_train; // context size the model was trained on
uint32_t n_embd;
uint32_t n_head;
uint32_t n_head_kv;
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_ff;
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
uint32_t n_vocab_type = 0; // for BERT-style token types

// TODO: find a more compact way to add more per-layer hyper-parameters
std::vector<int32_t> n_head_vec;
std::vector<int32_t> n_head_kv_vec;
std::vector<int32_t> n_ff_vec;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;

uint32_t n_layer_dense_lead = 0;
uint32_t n_lora_q = 0;
@@ -2160,19 +2157,16 @@ struct llama_hparams {
if (this->n_vocab != other.n_vocab) return true;
if (this->n_ctx_train != other.n_ctx_train) return true;
if (this->n_embd != other.n_embd) return true;
if (this->n_head != other.n_head) return true;
if (this->n_head_kv != other.n_head_kv) return true;
if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true;
if (this->n_embd_head_k != other.n_embd_head_k) return true;
if (this->n_embd_head_v != other.n_embd_head_v) return true;
if (this->n_ff != other.n_ff) return true;
if (this->n_expert != other.n_expert) return true;
if (this->n_expert_used != other.n_expert_used) return true;

if (this->n_head_vec != other.n_head_vec) return true;
if (this->n_head_kv_vec != other.n_head_kv_vec) return true;
if (this->n_ff_vec != other.n_ff_vec) return true;
if (this->n_head_arr != other.n_head_arr) return true;
if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
if (this->n_ff_arr != other.n_ff_arr) return true;

if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
if (this->n_lora_q != other.n_lora_q) return true;
@@ -2202,53 +2196,53 @@ struct llama_hparams {
return false;
}

// TODO: deduplicate per-layer getters
uint32_t n_head_l(uint32_t layer) const {
if (layer < n_head_vec.size()) {
int32_t n_h_l = n_head_vec[layer];
// TODO: what should happen when it's negative?
GGML_ASSERT(n_h_l >= 0);
return n_h_l;
uint32_t n_head(uint32_t il = 0) const {
if (il < n_layer) {
return n_head_arr[il];
}
return n_head;

GGML_ASSERT(false);
return 0;
}

uint32_t n_head_kv_l(uint32_t layer) const {
if (layer < n_head_kv_vec.size()) {
int32_t n_hkv_l = n_head_kv_vec[layer];
// TODO: what should happen when it's negative?
GGML_ASSERT(n_hkv_l >= 0);
return n_hkv_l;
uint32_t n_head_kv(uint32_t il = 0) const {
if (il < n_layer) {
return n_head_kv_arr[il];
}
return n_head_kv;

GGML_ASSERT(false);
return 0;
}

uint32_t n_ff_l(uint32_t layer) const {
if (layer < n_ff_vec.size()) {
int32_t n_f_l = n_ff_vec[layer];
// TODO: what should happen when it's negative?
GGML_ASSERT(n_f_l >= 0);
return n_f_l;
uint32_t n_ff(uint32_t il = 0) const {
if (il < n_layer) {
return n_ff_arr[il];
}
return n_ff;

GGML_ASSERT(false);
return 0;
}

uint32_t n_gqa(uint32_t layer = 0) const {
uint32_t n_head_kv = n_head_kv_l(layer);
uint32_t n_head = n_head_l(layer);
uint32_t n_gqa(uint32_t il = 0) const {
const uint32_t n_head = this->n_head(il);
const uint32_t n_head_kv = this->n_head_kv(il);

if (n_head_kv == 0) {
return 0;
}

return n_head/n_head_kv;
}

uint32_t n_embd_k_gqa(uint32_t layer = 0) const { // dimension of key embeddings across all k-v heads
uint32_t n_head_kv = n_head_kv_l(layer);
uint32_t n_embd_k_gqa(uint32_t il = 0) const { // dimension of key embeddings across all k-v heads
const uint32_t n_head_kv = this->n_head_kv(il);

return n_embd_head_k * n_head_kv;
}

uint32_t n_embd_v_gqa(uint32_t layer = 0) const { // dimension of value embeddings across all k-v heads
uint32_t n_head_kv = n_head_kv_l(layer);
uint32_t n_embd_v_gqa(uint32_t il = 0) const { // dimension of value embeddings across all k-v heads
const uint32_t n_head_kv = this->n_head_kv(il);

return n_embd_head_v * n_head_kv;
}
@@ -2265,6 +2259,8 @@ struct llama_hparams {
}
};

static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_batch;
@@ -3797,11 +3793,11 @@ struct llama_model_loader {
struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);

// TODO: allow ANY lossless cast
// GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T));
switch (arr_info.gt) {
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value)); break;
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_INT32: GGML_ASSERT(
(std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break;
default:
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
}
@@ -3812,8 +3808,38 @@ struct llama_model_loader {
return true;
}

template<typename T, size_t N_MAX>
bool get_arr(const std::string & key, std::array<T, N_MAX> & result, const bool required = true) {
const int kid = gguf_find_key(meta, key.c_str());

if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
if (required) {
throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
}
return false;
}

struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);

switch (arr_info.gt) {
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_INT32: GGML_ASSERT(
(std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break;
default:
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
}

GGML_ASSERT(arr_info.length <= N_MAX);

std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());

return true;
}

template<typename T>
bool get_arr(const enum llm_kv kid, T& result, const bool required = true) {
bool get_arr(const enum llm_kv kid, T & result, const bool required = true) {
return get_arr(llm_kv(kid), result, required);
}
@@ -3838,6 +3864,50 @@ struct llama_model_loader {
return get_key(llm_kv(kid), result, required);
}

// get array of n <= N_MAX elements, or a single element repeated n times
template<typename T, size_t N_MAX>
bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true) {
GGML_ASSERT(n <= N_MAX);

const int kid = gguf_find_key(meta, key.c_str());

if (kid < 0) {
if (required) {
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
}
return false;
}

if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) {
struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);

if (n != arr_info.length) {
throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length));
}

return get_arr(key, result, required);
} else {
T value;

bool ok = get_key(key, value, required);
if (!ok) {
return false;
}

for (uint32_t i = 0; i < n; i++) {
result[i] = value;
}

return true;
}
}

template<typename T>
bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true) {
return get_key_or_arr(llm_kv(kid), result, n, required);
}

std::string get_arch_name() const {
return arch_name;
}
@@ -4409,7 +4479,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);

// get hparams kv
ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);

// everything past this point is not vocab-related
if (hparams.vocab_only) {
@@ -4430,22 +4500,20 @@ static void llm_load_hparams(
GGML_ASSERT(hparams.n_expert_used == 0);
}

// per-layer or global values
if (!ml.get_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_vec, false)) {
ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
}
if (!ml.get_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_vec, false)) {
ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
} else {
hparams.n_head = hparams.n_head_vec[0];
}
// zero-out the per-layer hparams
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);

ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);

GGML_ASSERT(hparams.n_head() > 0);

// n_head_kv is optional, default to n_head
hparams.n_head_kv = hparams.n_head;
hparams.n_head_kv_arr = hparams.n_head_arr;

if (!ml.get_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_vec, false)) {
ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false);
}
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false);

bool rope_finetuned = false;
ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
@@ -4475,23 +4543,23 @@ static void llm_load_hparams(

// sanity check for n_rot (optional)
{
hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
hparams.n_rot = hparams.n_embd / hparams.n_head();

ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);

if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
}
}
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
// gpt-j n_rot = rotary_dim
}

hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);

hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);

// arch-specific KVs
@@ -4516,7 +4584,7 @@ static void llm_load_hparams(
case 40: model.type = e_model::MODEL_13B; break;
case 48: model.type = e_model::MODEL_34B; break;
case 60: model.type = e_model::MODEL_30B; break;
case 80: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_65B : e_model::MODEL_70B; break;
case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
}
@@ -4685,7 +4753,7 @@ static void llm_load_hparams(
switch (hparams.n_layer) {
case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break;
case 32: model.type = e_model::MODEL_7B; break;
case 40: model.type = hparams.n_head == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break;
case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break;
case 80: model.type = e_model::MODEL_70B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
@@ -4893,40 +4961,40 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
switch (hparams.n_layer) {
case 6:
switch (hparams.n_ff) {
switch (hparams.n_ff()) {
case 512: model.type = e_model::MODEL_14M; break;
case 2048: model.type = e_model::MODEL_70M; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
case 12:
switch (hparams.n_ff) {
switch (hparams.n_ff()) {
case 3072: model.type = e_model::MODEL_160M; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
case 16:
switch (hparams.n_ff) {
switch (hparams.n_ff()) {
case 8192: model.type = e_model::MODEL_1B; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
case 24:
switch (hparams.n_ff) {
switch (hparams.n_ff()) {
case 4096: model.type = e_model::MODEL_410M; break;
case 8192: model.type = e_model::MODEL_1_4B; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
case 32:
switch (hparams.n_ff) {
switch (hparams.n_ff()) {
case 10240: model.type = e_model::MODEL_2_8B; break;
case 16384: model.type = e_model::MODEL_6_9B; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
case 36:
switch (hparams.n_ff) {
switch (hparams.n_ff()) {
case 20480: model.type = e_model::MODEL_12B; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
case 44:
switch (hparams.n_ff) {
switch (hparams.n_ff()) {
case 24576: model.type = e_model::MODEL_20B; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
@@ -5491,6 +5559,35 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {

const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);

auto print_f = [](std::function<uint32_t(uint32_t)> f, uint32_t n) {
bool is_var = false;

std::vector<uint32_t> v;
for (uint32_t i = 0; i < n; ++i) {
v.push_back(f(i));
if (v[i] != v[0]) {
is_var = true;
}
}

std::stringstream ss;

if (is_var) {
ss << "[";
for (uint32_t i = 0; i < n; ++i) {
ss << v[i];
if (i < n - 1) {
ss << ", ";
}
}
ss << "]";
} else {
ss << v[0];
}

return ss.str();
};

// hparams
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch));
@@ -5499,21 +5596,21 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head);
LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %u\n", __func__, hparams.n_embd_k_gqa());
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %u\n", __func__, hparams.n_embd_v_gqa());
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
@@ -5722,13 +5819,13 @@ static bool llm_load_tensors(
// create tensors for the weights
{
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd_head = (hparams.n_head == 0) ? 0 : n_embd / hparams.n_head;
const int64_t n_embd_head = n_embd / hparams.n_head();
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const int64_t n_embd_gqa = n_embd_v_gqa;
const int64_t n_vocab = hparams.n_vocab;
const int64_t n_vocab_type = hparams.n_vocab_type;
const int64_t n_ff = hparams.n_ff;
const int64_t n_ff = hparams.n_ff();
const int64_t n_expert = hparams.n_expert;

if (n_expert > 0 && hparams.n_expert_used == 0) {
@@ -6249,8 +6346,8 @@ static bool llm_load_tensors(
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);

// optional q and k layernorms, present in StableLM 2 12B
layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head()}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv()}, llama_model_loader::TENSOR_NOT_REQUIRED);

// optional FFN norm, not present in StableLM 2 12B which uses parallel residual
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -6621,7 +6718,7 @@ static bool llm_load_tensors(
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading

const int64_t n_ff = hparams.n_ff;
const int64_t n_ff = hparams.n_ff();
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
@@ -6634,10 +6731,10 @@ static bool llm_load_tensors(

layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});

layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * hparams.n_head()});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head(), n_embd});

layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
@@ -6653,7 +6750,7 @@ static bool llm_load_tensors(
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading

const int64_t n_ff = hparams.n_ff;
const int64_t n_ff = hparams.n_ff();
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
@@ -6666,10 +6763,10 @@ static bool llm_load_tensors(

layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});

layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * hparams.n_head()});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head(), n_embd});
layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});

layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
@@ -6818,8 +6915,8 @@ static bool llm_load_tensors(
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});

if (n_layer >= 64){
layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head});
layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv});
layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head()});
layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv()});
}

layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
@@ -6873,11 +6970,11 @@ static bool llm_load_tensors(
}

for (int i = 0; i < n_layer; ++i) {
const int64_t n_head = hparams.n_head_l(i);
const int64_t n_head_qkv = 2*hparams.n_head_kv_l(i) + n_head;
const int64_t n_head = hparams.n_head(i);
const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head;
const int64_t n_embd_head = hparams.n_embd_head_k;

const int64_t n_ff = hparams.n_ff_l(i);
const int64_t n_ff = hparams.n_ff(i);

ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
@@ -7004,13 +7101,13 @@ static bool llm_load_tensors(

if (!is_lite) {
layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k});
layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.n_head() * hparams.n_embd_head_k});
} else {
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
}
layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope});
layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd});
layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head() * (n_embd_head_qk_nope + hparams.n_embd_head_v)});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { hparams.n_head() * ( hparams.n_embd_head_v), n_embd});

layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
@@ -7646,8 +7743,8 @@ static struct ggml_tensor * llm_build_kqv(
const llm_build_cb & cb,
int il) {
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head_l(il);
const int64_t n_head_kv = hparams.n_head_kv_l(il);
const int64_t n_head = hparams.n_head(il);
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_head_v = hparams.n_embd_head_v;
@@ -7857,8 +7954,8 @@ struct llm_build_context {
n_layer (hparams.n_layer),
n_rot (hparams.n_rot),
n_ctx (cparams.n_ctx),
n_head (hparams.n_head),
n_head_kv (hparams.n_head_kv),
n_head (hparams.n_head()),
n_head_kv (hparams.n_head_kv()),
n_embd_head_k (hparams.n_embd_head_k),
n_embd_k_gqa (hparams.n_embd_k_gqa()),
n_embd_head_v (hparams.n_embd_head_v),
@@ -7926,7 +8023,7 @@ struct llm_build_context {

for (int il = 0; il < n_layer; ++il) {
const int64_t n_head_kv = hparams.n_head_kv_l(il);
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
struct ggml_tensor * rope_factors = build_rope_factors(il);
struct ggml_tensor * tmp =
@@ -11825,8 +11922,8 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

for (int il = 0; il < n_layer; ++il) {
const int64_t n_head = hparams.n_head_l(il);
const int64_t n_head_kv = hparams.n_head_kv_l(il);
const int64_t n_head = hparams.n_head(il);
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_head_qkv = 2*n_head_kv + n_head;

cur = inpL;