gguf : calculate n_mult

This commit is contained in:
M. Yusuf Sarıgöz 2023-08-10 18:49:08 +03:00
parent 22de6c5c4c
commit 42cc04d11d

View File

@ -514,16 +514,30 @@ struct ggml_context * ctx_data = NULL;
return gguf_get_arr_n(gguf_ctx, i);
}
int find_n_mult(const int n_ff, const int n_embd) {
int n_mults[3] = {8192, 1, -1};
for (int i = 0; i < 3; ++i) {
int calc_ff = (((8 * n_embd) / 3 + n_mults[i] - 1) / n_mults[i]) * n_mults[i];
if (calc_ff == n_ff) {
return n_mults[i];
}
}
throw std::runtime_error(format("failed to find n_mult for n_ff = %d and n_emb = %d\n", n_ff, n_embd));
}
void read_hparams() {
// TODO make keysconstants in header
// TODO: read all hparams from file
hparams.n_vocab = read_n_vocab();
hparams.n_ctx = read_u32("llama.context_length");
hparams.n_embd = read_u32("llama.embedding_length");
//hparams.n_mult = file.read_u32();
uint32_t n_ff = read_u32("llama.feed_forward_length");
hparams.n_mult = find_n_mult(n_ff, hparams.n_embd);
hparams.n_head = read_u32("llama.attention.head_count");
hparams.n_layer = read_u32("llama.layer_count");
//hparams.n_rot = file.read_u32();
hparams.n_rot = hparams.n_embd / hparams.n_head;
//hparams.ftype = (enum llama_ftype) file.read_u32();
// LLaMAv2
@ -568,7 +582,7 @@ struct ggml_context * ctx_data = NULL;
for (uint32_t j = 0; j < n_dims; ++j) {
tensor.ne[j] = cur->ne[j];
}
if (n_dims < 1 || n_dims > 2) {
throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name, n_dims));
}