convert-llama-h5-to-gguf.py : fixes

This commit is contained in:
klosax 2023-08-14 11:14:24 +02:00 committed by GitHub
parent d753dfbcc8
commit a7d226f871
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,7 @@
# HF llama --> gguf conversion, GQA/70b not supported # HF llama --> gguf conversion, GQA/70b not supported
import gguf import gguf
import gguf_tensor_map as tmap import gguf_namemap as tmap
import os import os
import sys import sys
import struct import struct
@ -79,14 +79,23 @@ gguf_writer = gguf.GGUFWriter.open(fname_out)
print("gguf: get model metadata") print("gguf: get model metadata")
llm_arch = "llama" llm_arch = "llama"
hf_repo = hparams["_name_or_path"]
head_count = hparams["num_attention_heads"]
head_count_kv = hparams["num_key_value_heads"]
block_count = hparams["num_hidden_layers"] block_count = hparams["num_hidden_layers"]
head_count = hparams["num_attention_heads"]
if "num_key_value_heads" in hparams:
head_count_kv = hparams["num_key_value_heads"]
else:
head_count_kv = head_count
if "_name_or_path" in hparams:
hf_repo = hparams["_name_or_path"]
else:
hf_repo=""
gguf_writer.add_name(last_dir)
gguf_writer.add_architecture(llm_arch) gguf_writer.add_architecture(llm_arch)
guff_writer.add_source_hf_repo(hf_repo) gguf_writer.add_name(last_dir)
gguf_writer.add_file_type( "All tensors F32" if ftype == 0 else "Most tensors F16, some F32")
gguf_writer.add_source_hf_repo(hf_repo)
gguf_writer.add_context_length(llm_arch, hparams["max_position_embeddings"]) gguf_writer.add_context_length(llm_arch, hparams["max_position_embeddings"])
gguf_writer.add_embedding_length(llm_arch, hparams["hidden_size"]) gguf_writer.add_embedding_length(llm_arch, hparams["hidden_size"])
gguf_writer.add_block_count(llm_arch, block_count) gguf_writer.add_block_count(llm_arch, block_count)
@ -173,7 +182,7 @@ if Path(dir_model + "/tokenizer.json").is_file():
# TENSORS # TENSORS
tensor_map = tmap.get_tensor_map(block_count) tensor_map = tmap.get_tensor_namemap(block_count)
# tensor info # tensor info
print("gguf: get tensor metadata") print("gguf: get tensor metadata")