mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-22 09:39:08 +01:00
gguf : calculate n_mult
This commit is contained in:
parent
22de6c5c4c
commit
42cc04d11d
@ -514,16 +514,30 @@ struct ggml_context * ctx_data = NULL;
|
||||
return gguf_get_arr_n(gguf_ctx, i);
|
||||
}
|
||||
|
||||
int find_n_mult(const int n_ff, const int n_embd) {
|
||||
int n_mults[3] = {8192, 1, -1};
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
int calc_ff = (((8 * n_embd) / 3 + n_mults[i] - 1) / n_mults[i]) * n_mults[i];
|
||||
if (calc_ff == n_ff) {
|
||||
return n_mults[i];
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error(format("failed to find n_mult for n_ff = %d and n_emb = %d\n", n_ff, n_embd));
|
||||
}
|
||||
|
||||
void read_hparams() {
|
||||
|
||||
// TODO make keysconstants in header
|
||||
// TODO: read all hparams from file
|
||||
hparams.n_vocab = read_n_vocab();
|
||||
hparams.n_ctx = read_u32("llama.context_length");
|
||||
hparams.n_embd = read_u32("llama.embedding_length");
|
||||
//hparams.n_mult = file.read_u32();
|
||||
uint32_t n_ff = read_u32("llama.feed_forward_length");
|
||||
hparams.n_mult = find_n_mult(n_ff, hparams.n_embd);
|
||||
hparams.n_head = read_u32("llama.attention.head_count");
|
||||
hparams.n_layer = read_u32("llama.layer_count");
|
||||
//hparams.n_rot = file.read_u32();
|
||||
hparams.n_rot = hparams.n_embd / hparams.n_head;
|
||||
//hparams.ftype = (enum llama_ftype) file.read_u32();
|
||||
|
||||
// LLaMAv2
|
||||
|
Loading…
Reference in New Issue
Block a user