gguf : get rid of n_mult, read n_ff from file

This commit is contained in:
M. Yusuf Sarıgöz 2023-08-11 23:50:38 +03:00
parent f44bbd3d88
commit e732423280

View File

@ -177,15 +177,12 @@ struct llama_hparams {
uint32_t n_vocab = 32000;
uint32_t n_ctx = 512; // this is provided as user input?
uint32_t n_embd = 4096;
uint32_t n_mult = 256;
uint32_t n_head = 32;
uint32_t n_head_kv = 32;
uint32_t n_layer = 32;
uint32_t n_rot = 64;
uint32_t n_ff = 11008;
// LLaMAv2
// TODO: load from model data hparams
float f_ffn_mult = 1.0f;
float f_rms_norm_eps = LLAMA_DEFAULT_RMS_EPS;
float rope_freq_base = 10000.0f;
@ -467,7 +464,7 @@ struct llama_load_tensors_map {
};
enum gguf_file_version {
GGUF_FILE_VERSION_V1
GGUF_FILE_VERSION_V1 = 1,
};
@ -490,6 +487,7 @@ struct ggml_context * ctx_data = NULL;
};
gguf_ctx = gguf_init_from_file(fname, params);
file_version = (enum gguf_file_version) gguf_get_version(gguf_ctx);
read_hparams();
read_vocab();
@ -505,6 +503,15 @@ struct ggml_context * ctx_data = NULL;
return gguf_get_val_u32(gguf_ctx, i);
}
float read_f32(const char * key) {
int i = gguf_find_key(gguf_ctx, key);
if (i == -1) {
throw std::runtime_error(format("cannot find param with key %s\n", key));
}
return gguf_get_val_f32(gguf_ctx, i);
}
int read_n_vocab() {
int i = gguf_find_key(gguf_ctx, "tokenizer.ggml.tokens");
if (i == -1) {
@ -514,18 +521,6 @@ struct ggml_context * ctx_data = NULL;
return gguf_get_arr_n(gguf_ctx, i);
}
int find_n_mult(const int n_ff, const int n_embd) {
int n_mults[3] = {8192, 1, -1};
for (int i = 0; i < 3; ++i) {
int calc_ff = (((8 * n_embd) / 3 + n_mults[i] - 1) / n_mults[i]) * n_mults[i];
if (calc_ff == n_ff) {
return n_mults[i];
}
}
throw std::runtime_error(format("failed to find n_mult for n_ff = %d and n_embd = %d\n", n_ff, n_embd));
}
void read_hparams() {
// TODO make keysconstants in header
@ -533,13 +528,11 @@ struct ggml_context * ctx_data = NULL;
hparams.n_vocab = read_n_vocab();
hparams.n_ctx = read_u32("llama.context_length");
hparams.n_embd = read_u32("llama.embedding_length");
uint32_t n_ff = read_u32("llama.feed_forward_length");
GGML_UNUSED(n_ff);
//hparams.n_mult = find_n_mult(n_ff, hparams.n_embd);
hparams.n_ff = read_u32("llama.feed_forward_length");
hparams.n_head = read_u32("llama.attention.head_count");
hparams.n_layer = read_u32("llama.layer_count");
hparams.n_rot = hparams.n_embd / hparams.n_head;
//hparams.ftype = (enum llama_ftype) file.read_u32();
hparams.n_rot = read_u32("llama.rope.dimension_count");
hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon");
// LLaMAv2
// hparams.n_head_kv = read_u32("llama.attention.head_count_kv");
@ -1125,6 +1118,7 @@ static void llama_model_load_internal(
bool vocab_only,
llama_progress_callback progress_callback,
void * progress_callback_user_data) {
GGML_UNUSED(rms_norm_eps); // TODO: update function signature to remove this
model.t_start_us = ggml_time_us();
@ -1137,9 +1131,6 @@ static void llama_model_load_internal(
auto & hparams = model.hparams;
// TODO: read from file
hparams.f_rms_norm_eps = rms_norm_eps;
{
switch (hparams.n_layer) {
case 26: model.type = e_model::MODEL_3B; break;
@ -1162,25 +1153,19 @@ static void llama_model_load_internal(
if (model.type == e_model::MODEL_65B && n_gqa == 8) {
fprintf(stderr, "%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
model.type = e_model::MODEL_70B;
hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model
}
}
hparams.rope_freq_base = rope_freq_base;
hparams.rope_freq_scale = rope_freq_scale;
}
// ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199
const uint32_t n_ff_raw = 2*(4*hparams.n_embd)/3;
const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw;
const uint32_t n_ff = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
//const uint32_t n_ff = 28672;
const uint32_t n_ff = hparams.n_ff;
{
fprintf(stderr, "%s: format = %s\n", __func__, gguf_file_version_name(file_version));
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx);
fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);