mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-19 08:20:10 +01:00
gguf : style fixes in simple conversion script
This commit is contained in:
parent
2f8fc92d86
commit
22c61c5b45
@ -23,6 +23,7 @@ NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
|
|||||||
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
|
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
|
||||||
if n_kv_head is not None and n_head != n_kv_head:
|
if n_kv_head is not None and n_head != n_kv_head:
|
||||||
n_head //= n_kv_head
|
n_head //= n_kv_head
|
||||||
|
|
||||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||||
.swapaxes(1, 2)
|
.swapaxes(1, 2)
|
||||||
.reshape(weights.shape))
|
.reshape(weights.shape))
|
||||||
@ -30,12 +31,14 @@ def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] =
|
|||||||
|
|
||||||
def count_model_parts(dir_model: str) -> int:
|
def count_model_parts(dir_model: str) -> int:
|
||||||
num_parts = 0
|
num_parts = 0
|
||||||
|
|
||||||
for filename in os.listdir(dir_model):
|
for filename in os.listdir(dir_model):
|
||||||
if filename.startswith("pytorch_model-"):
|
if filename.startswith("pytorch_model-"):
|
||||||
num_parts += 1
|
num_parts += 1
|
||||||
|
|
||||||
if num_parts > 0:
|
if num_parts > 0:
|
||||||
print("gguf: found " + str(num_parts) + " model parts")
|
print("gguf: found " + str(num_parts) + " model parts")
|
||||||
|
|
||||||
return num_parts
|
return num_parts
|
||||||
|
|
||||||
|
|
||||||
@ -43,6 +46,7 @@ if len(sys.argv) < 3:
|
|||||||
print("Usage: convert-h5-to-ggml.py dir-model ftype\n")
|
print("Usage: convert-h5-to-ggml.py dir-model ftype\n")
|
||||||
print(" ftype == 0 -> float32")
|
print(" ftype == 0 -> float32")
|
||||||
print(" ftype == 1 -> float16")
|
print(" ftype == 1 -> float16")
|
||||||
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
@ -54,7 +58,8 @@ last_dir = os.path.basename(os.path.normpath(dir_model))
|
|||||||
# possible tensor data types
|
# possible tensor data types
|
||||||
# ftype == 0 -> float32
|
# ftype == 0 -> float32
|
||||||
# ftype == 1 -> float16
|
# ftype == 1 -> float16
|
||||||
#
|
|
||||||
|
|
||||||
# map from ftype to string
|
# map from ftype to string
|
||||||
ftype_str = ["f32", "f16"]
|
ftype_str = ["f32", "f16"]
|
||||||
|
|
||||||
@ -63,6 +68,7 @@ if len(sys.argv) > 2:
|
|||||||
ftype = int(sys.argv[2])
|
ftype = int(sys.argv[2])
|
||||||
if ftype < 0 or ftype > 1:
|
if ftype < 0 or ftype > 1:
|
||||||
print("Invalid ftype: " + str(ftype))
|
print("Invalid ftype: " + str(ftype))
|
||||||
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
|
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
|
||||||
@ -74,12 +80,13 @@ with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
|
|||||||
|
|
||||||
if hparams["architectures"][0] != "LlamaForCausalLM":
|
if hparams["architectures"][0] != "LlamaForCausalLM":
|
||||||
print("Model architecture not supported: " + hparams["architectures"][0])
|
print("Model architecture not supported: " + hparams["architectures"][0])
|
||||||
|
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
# get number of model parts
|
# get number of model parts
|
||||||
num_parts = count_model_parts(dir_model)
|
num_parts = count_model_parts(dir_model)
|
||||||
|
|
||||||
gguf_writer = gguf.GGUFWriter(fname_out, architecture="llama")
|
gguf_writer = gguf.GGUFWriter(fname_out, arch="llama")
|
||||||
|
|
||||||
|
|
||||||
print("gguf: get model metadata")
|
print("gguf: get model metadata")
|
||||||
@ -103,12 +110,12 @@ elif "max_position_embeddings" in hparams:
|
|||||||
ctx_length = hparams["max_position_embeddings"]
|
ctx_length = hparams["max_position_embeddings"]
|
||||||
else:
|
else:
|
||||||
print("gguf: can not find ctx length parameter.")
|
print("gguf: can not find ctx length parameter.")
|
||||||
|
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
|
|
||||||
gguf_writer.add_architecture()
|
gguf_writer.add_architecture()
|
||||||
gguf_writer.add_name(last_dir)
|
gguf_writer.add_name(last_dir)
|
||||||
gguf_writer.add_file_type("All tensors F32" if ftype == 0 else "Most tensors F16, some F32")
|
|
||||||
gguf_writer.add_source_hf_repo(hf_repo)
|
gguf_writer.add_source_hf_repo(hf_repo)
|
||||||
gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||||
gguf_writer.add_context_length(ctx_length)
|
gguf_writer.add_context_length(ctx_length)
|
||||||
@ -247,6 +254,7 @@ for part_name in part_names:
|
|||||||
name = tensor_map[name[:-5]] + ".bias"
|
name = tensor_map[name[:-5]] + ".bias"
|
||||||
else:
|
else:
|
||||||
print("Can not map tensor '" + name + "'")
|
print("Can not map tensor '" + name + "'")
|
||||||
|
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user