mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-19 08:20:10 +01:00
convert.py : n_head_kv optional and .gguf file extension
This commit is contained in:
parent
dd016cc246
commit
d646c4efce
41
convert.py
41
convert.py
@ -150,15 +150,20 @@ class Params:
|
|||||||
def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
|
def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
|
||||||
config = json.load(open(config_path))
|
config = json.load(open(config_path))
|
||||||
|
|
||||||
n_vocab = config["vocab_size"];
|
n_vocab = config["vocab_size"]
|
||||||
n_embd = config["hidden_size"];
|
n_embd = config["hidden_size"]
|
||||||
n_layer = config["num_hidden_layers"];
|
n_layer = config["num_hidden_layers"]
|
||||||
n_ff = config["intermediate_size"];
|
n_ff = config["intermediate_size"]
|
||||||
n_head = config["num_attention_heads"];
|
n_head = config["num_attention_heads"]
|
||||||
n_head_kv = config["num_key_value_heads"];
|
|
||||||
f_norm_eps = config["rms_norm_eps"];
|
|
||||||
|
|
||||||
n_mult = Params.find_n_mult(n_ff, n_embd);
|
if "num_key_value_heads" in config:
|
||||||
|
n_head_kv = config["num_key_value_heads"]
|
||||||
|
else:
|
||||||
|
n_head_kv = None
|
||||||
|
|
||||||
|
f_norm_eps = config["rms_norm_eps"]
|
||||||
|
|
||||||
|
n_mult = Params.find_n_mult(n_ff, n_embd)
|
||||||
|
|
||||||
if "max_sequence_length" in config:
|
if "max_sequence_length" in config:
|
||||||
n_ctx = config["max_sequence_length"]
|
n_ctx = config["max_sequence_length"]
|
||||||
@ -186,15 +191,15 @@ class Params:
|
|||||||
def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
|
def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
|
||||||
config = json.load(open(config_path))
|
config = json.load(open(config_path))
|
||||||
|
|
||||||
n_vocab = config["vocab_size"];
|
n_vocab = config["vocab_size"]
|
||||||
n_embd = config["dim"];
|
n_embd = config["dim"]
|
||||||
n_layer = config["n_layers"];
|
n_layer = config["n_layers"]
|
||||||
n_mult = config["multiple_of"];
|
n_mult = config["multiple_of"]
|
||||||
n_ctx = 2048 if config["norm_eps"] == 1e-06 else 4096 # hack to determine LLaMA v1 vs v2
|
n_ctx = 2048 if config["norm_eps"] == 1e-06 else 4096 # hack to determine LLaMA v1 vs v2
|
||||||
n_ff = -1;
|
n_ff = -1
|
||||||
n_head = config["n_heads"];
|
n_head = config["n_heads"]
|
||||||
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head;
|
n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
|
||||||
f_norm_eps = config["norm_eps"];
|
f_norm_eps = config["norm_eps"]
|
||||||
|
|
||||||
if n_vocab == -1:
|
if n_vocab == -1:
|
||||||
n_vocab = model["tok_embeddings.weight"].shape[0]
|
n_vocab = model["tok_embeddings.weight"].shape[0]
|
||||||
@ -714,7 +719,7 @@ class OutputFile:
|
|||||||
self.gguf.add_feed_forward_length (params.n_ff)
|
self.gguf.add_feed_forward_length (params.n_ff)
|
||||||
self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
|
self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
|
||||||
self.gguf.add_head_count (params.n_head)
|
self.gguf.add_head_count (params.n_head)
|
||||||
self.gguf.add_head_count_kv (params.n_head_kv)
|
if params.n_head_kv is not None: self.gguf.add_head_count_kv(params.n_head_kv)
|
||||||
self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)
|
self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)
|
||||||
|
|
||||||
def add_meta_vocab(self, vocab: Vocab) -> None:
|
def add_meta_vocab(self, vocab: Vocab) -> None:
|
||||||
@ -934,7 +939,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
|
|||||||
GGMLFileType.AllF32: "f32",
|
GGMLFileType.AllF32: "f32",
|
||||||
GGMLFileType.MostlyF16: "f16",
|
GGMLFileType.MostlyF16: "f16",
|
||||||
}[file_type]
|
}[file_type]
|
||||||
ret = model_paths[0].parent / f"ggml-model-{namestr}.bin"
|
ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
|
||||||
if ret in model_paths:
|
if ret in model_paths:
|
||||||
sys.stderr.write(
|
sys.stderr.write(
|
||||||
f"Error: Default output path ({ret}) would overwrite the input. "
|
f"Error: Default output path ({ret}) would overwrite the input. "
|
||||||
|
Loading…
Reference in New Issue
Block a user