From 42cc04d11d24b7b41e19ecaa38b46350faca141b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=2E=20Yusuf=20Sar=C4=B1g=C3=B6z?=
Date: Thu, 10 Aug 2023 18:49:08 +0300
Subject: [PATCH] gguf : calculate n_mult

---
 gguf-llama.cpp | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/gguf-llama.cpp b/gguf-llama.cpp
index d385ef79a..b88a2d8bf 100644
--- a/gguf-llama.cpp
+++ b/gguf-llama.cpp
@@ -514,16 +514,30 @@ struct ggml_context * ctx_data = NULL;
         return gguf_get_arr_n(gguf_ctx, i);
     }
 
+    int find_n_mult(const int n_ff, const int n_embd) {
+        int n_mults[3] = {8192, 1, -1};
+        for (int i = 0; i < 3; ++i) {
+            int calc_ff = (((8 * n_embd) / 3 + n_mults[i] - 1) / n_mults[i]) * n_mults[i];
+            if (calc_ff == n_ff) {
+                return n_mults[i];
+            }
+        }
+
+        throw std::runtime_error(format("failed to find n_mult for n_ff = %d and n_embd = %d\n", n_ff, n_embd));
+    }
+
     void read_hparams() {
         // TODO: make keys constants in header
         // TODO: read all hparams from file
 
         hparams.n_vocab = read_n_vocab();
+        hparams.n_ctx   = read_u32("llama.context_length");
         hparams.n_embd  = read_u32("llama.embedding_length");
-        //hparams.n_mult  = file.read_u32();
+        uint32_t n_ff   = read_u32("llama.feed_forward_length");
+        hparams.n_mult  = find_n_mult(n_ff, hparams.n_embd);
         hparams.n_head  = read_u32("llama.attention.head_count");
         hparams.n_layer = read_u32("llama.layer_count");
-        //hparams.n_rot   = file.read_u32();
+        hparams.n_rot   = hparams.n_embd / hparams.n_head;
         //hparams.ftype   = (enum llama_ftype) file.read_u32();
 
         // LLaMAv2
@@ -568,7 +582,7 @@ struct ggml_context * ctx_data = NULL;
         for (uint32_t j = 0; j < n_dims; ++j) {
             tensor.ne[j] = cur->ne[j];
         }
-        
+
         if (n_dims < 1 || n_dims > 2) {
             throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name, n_dims));
         }
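
Note (not part of the patch): find_n_mult() above inverts the LLaMA feed-forward
sizing rule, n_ff = (8 * n_embd) / 3 rounded up to the nearest multiple of n_mult,
by trying a short list of candidate multipliers and returning the one that
reproduces the n_ff read from the GGUF file. Below is a minimal standalone sketch
of that recovery, assuming the sizing rule stated here; the n_embd/n_ff values in
main() are illustrative inputs that happen to satisfy the rule for n_mult = 1, not
values taken from any specific model.

    // Standalone sketch of the n_mult recovery performed by find_n_mult()
    // in the patch. Assumes n_ff = round_up((8 * n_embd) / 3, n_mult).
    #include <cstdio>
    #include <stdexcept>

    static int find_n_mult(const int n_ff, const int n_embd) {
        const int n_mults[3] = {8192, 1, -1}; // candidate multipliers from the patch
        for (int i = 0; i < 3; ++i) {
            // round (8 * n_embd) / 3 up to the nearest multiple of n_mults[i]
            const int calc_ff = (((8 * n_embd) / 3 + n_mults[i] - 1) / n_mults[i]) * n_mults[i];
            if (calc_ff == n_ff) {
                return n_mults[i];
            }
        }
        throw std::runtime_error("failed to find n_mult");
    }

    int main() {
        // illustrative values: (8 * 4096) / 3 = 10922, which is trivially a
        // multiple of 1, so this prints "n_mult = 1"
        printf("n_mult = %d\n", find_n_mult(10922, 4096));
        return 0;
    }

Since the rounded-up value is compared for exact equality, a candidate of 1 only
matches when n_ff equals (8 * n_embd) / 3 exactly; any mismatch across all
candidates raises the same kind of error the patch throws.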