diff --git a/convert-llama-h5-to-gguf.py b/convert-llama-h5-to-gguf.py index a2b3f9a30..1a96a5426 100644 --- a/convert-llama-h5-to-gguf.py +++ b/convert-llama-h5-to-gguf.py @@ -18,11 +18,15 @@ NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]' # reverse HF permute back to original pth layout # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py + + def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray: - if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head + if n_kv_head is not None and n_head != n_kv_head: + n_head //= n_kv_head return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) + .swapaxes(1, 2) + .reshape(weights.shape)) + def count_model_parts(dir_model: str) -> int: num_parts = 0 @@ -34,6 +38,7 @@ def count_model_parts(dir_model: str) -> int: print("gguf: found " + str(num_parts) + " model parts") return num_parts + if len(sys.argv) < 3: print("Usage: convert-h5-to-ggml.py dir-model ftype\n") print(" ftype == 0 -> float32") @@ -74,12 +79,11 @@ if hparams["architectures"][0] != "LlamaForCausalLM": # get number of model parts num_parts = count_model_parts(dir_model) -gguf_writer = gguf.GGUFWriter.open(fname_out) +gguf_writer = gguf.GGUFWriter(fname_out, architecture="llama") print("gguf: get model metadata") -llm_arch = "llama" block_count = hparams["num_hidden_layers"] head_count = hparams["num_attention_heads"] @@ -91,7 +95,7 @@ else: if "_name_or_path" in hparams: hf_repo = hparams["_name_or_path"] else: - hf_repo="" + hf_repo = "" if "max_sequence_length" in hparams: ctx_length = hparams["max_sequence_length"] @@ -102,19 +106,19 @@ else: sys.exit() -gguf_writer.add_architecture(llm_arch) +gguf_writer.add_architecture() gguf_writer.add_name(last_dir) gguf_writer.add_file_type("All tensors F32" if ftype == 0 else "Most tensors F16, some F32") gguf_writer.add_source_hf_repo(hf_repo) -gguf_writer.add_tensor_data_layout(llm_arch, "Meta AI original pth") -gguf_writer.add_context_length(llm_arch, ctx_length) -gguf_writer.add_embedding_length(llm_arch, hparams["hidden_size"]) -gguf_writer.add_block_count(llm_arch, block_count) -gguf_writer.add_feed_forward_length(llm_arch, hparams["intermediate_size"]) -gguf_writer.add_rope_dimension_count(llm_arch, hparams["hidden_size"] // hparams["num_attention_heads"]) -gguf_writer.add_head_count(llm_arch, head_count) -gguf_writer.add_head_count_kv(llm_arch, head_count_kv) -gguf_writer.add_layer_norm_rms_eps(llm_arch, hparams["rms_norm_eps"]) +gguf_writer.add_tensor_data_layout("Meta AI original pth") +gguf_writer.add_context_length(ctx_length) +gguf_writer.add_embedding_length(hparams["hidden_size"]) +gguf_writer.add_block_count(block_count) +gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) +gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) +gguf_writer.add_head_count(head_count) +gguf_writer.add_head_count_kv(head_count_kv) +gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) # TOKENIZATION @@ -136,19 +140,23 @@ if Path(dir_model + "/tokenizer.model").is_file(): score: float piece = tokenizer.id_to_piece(i) - text = piece.encode("utf-8") + text = piece.encode("utf-8") score = tokenizer.get_score(i) - toktype = 1 # defualt to normal token type - if tokenizer.is_unknown(i): toktype = 2 - if tokenizer.is_control(i): toktype = 3 + toktype = 1 # defualt to normal token type + if tokenizer.is_unknown(i): + toktype = 2 + if tokenizer.is_control(i): + toktype = 3 # TODO: How to determinate if a token is user defined? # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto # if tokenizer.is_user_defined(i): toktype = 4 - if tokenizer.is_unused(i): toktype = 5 - if tokenizer.is_byte(i): toktype = 6 + if tokenizer.is_unused(i): + toktype = 5 + if tokenizer.is_byte(i): + toktype = 6 tokens.append(text) scores.append(score) @@ -212,7 +220,7 @@ else: ) for part_name in part_names: - print("gguf: loading model part '"+ part_name + "'") + print("gguf: loading model part '" + part_name + "'") model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu") for name in model_part.keys(): @@ -238,11 +246,12 @@ for part_name in part_names: elif name.endswith(".bias") and name[:-5] in tensor_map: name = tensor_map[name[:-5]] + ".bias" else: - print( "Can not map tensor '" + name + "'" ) + print("Can not map tensor '" + name + "'") sys.exit() n_dims = len(data.shape) data_dtype = data.dtype + old_dtype = data_dtype # if f32 desired, convert any float16 to float32 if ftype == 0 and data.dtype == np.float16: @@ -256,17 +265,19 @@ for part_name in part_names: if ftype == 1 and data.dtype == np.float32 and name.endswith(".weight") and n_dims == 2: data_dtype = np.float16 - data_nbytes = data.size * 2 if data_dtype == np.float16 else data.size * 4 + data = data.astype(data_dtype) - gguf_writer.add_tensor_info(name, data.shape, data_dtype, data_nbytes) + print(name + ", n_dims = " + n_dims + ", " + str(old_dtype) + " --> " + str(data.dtype)) + + gguf_writer.add_tensor(name, data) print("gguf: write header") gguf_writer.write_header_to_file() print("gguf: write metadata") gguf_writer.write_kv_data_to_file() -print("gguf: write tensor metadata") -gguf_writer.write_ti_data_to_file() +print("gguf: write tensors") +gguf_writer.write_tensors_to_file() # tensor data print("gguf: convert and write tensor data") @@ -279,7 +290,7 @@ else: ) for part_name in part_names: - print("gguf: loading model part '"+ part_name + "'") + print("gguf: loading model part '" + part_name + "'") model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu") for name in model_part.keys(): @@ -307,7 +318,7 @@ for part_name in part_names: elif name.endswith(".bias") and name[:-5] in tensor_map: name = tensor_map[name[:-5]] + ".bias" else: - print( "Can not map tensor '" + name + "'" ) + print("Can not map tensor '" + name + "'") sys.exit() n_dims = len(data.shape) @@ -325,8 +336,6 @@ for part_name in part_names: if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: data = data.astype(np.float16) - print(name + ", shape " + str(len(data.shape)) + ", " + str(old_dtype) + " --> " + str(data.dtype)) - gguf_writer.write_tensor_to_file(data) gguf_writer.close() diff --git a/gguf.py b/gguf.py index e7f6f0ac8..5bbba3ec7 100644 --- a/gguf.py +++ b/gguf.py @@ -1,11 +1,7 @@ -"""TODOs -1. Implement writers for known architectures, LLaMA in particular. -2. Add docstrings from the format specs. -3. After development is done, Convert it to a proper pip-installable Python package, and possibly move it to its own repo under ggml-org. -""" - +import shutil import sys import struct +import tempfile import numpy as np from enum import IntEnum @@ -15,152 +11,153 @@ from typing import Any, IO, List # constants # -GGUF_MAGIC = 0x47475546 -GGUF_VERSION = 1 +GGUF_MAGIC = 0x47475546 +GGUF_VERSION = 1 GGUF_DEFAULT_ALIGNMENT = 32 # general -KEY_GENERAL_ARCHITECTURE = "general.architecture" +KEY_GENERAL_ARCHITECTURE = "general.architecture" KEY_GENERAL_QUANTIZATION_VERSION = "general.quantization_version" -KEY_GENERAL_ALIGNMENT = "general.alignment" -KEY_GENERAL_NAME = "general.name" -KEY_GENERAL_AUTHOR = "general.author" -KEY_GENERAL_URL = "general.url" -KEY_GENERAL_DESCRIPTION = "general.description" -KEY_GENERAL_FILE_TYPE = "general.file_type" -KEY_GENERAL_LICENSE = "general.license" -KEY_GENERAL_SOURCE_URL = "general.source.url" -KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository" +KEY_GENERAL_ALIGNMENT = "general.alignment" +KEY_GENERAL_NAME = "general.name" +KEY_GENERAL_AUTHOR = "general.author" +KEY_GENERAL_URL = "general.url" +KEY_GENERAL_DESCRIPTION = "general.description" +KEY_GENERAL_FILE_TYPE = "general.file_type" +KEY_GENERAL_LICENSE = "general.license" +KEY_GENERAL_SOURCE_URL = "general.source.url" +KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository" # LLM -KEY_LLM_CONTEXT_LENGTH = "{llm}.context_length" -KEY_LLM_EMBEDDING_LENGTH = "{llm}.embedding_length" -KEY_LLM_BLOCK_COUNT = "{llm}.block_count" -KEY_LLM_FEED_FORWARD_LENGTH = "{llm}.feed_forward_length" -KEY_LLM_USE_PARALLEL_RESIDUAL = "{llm}.use_parallel_residual" -KEY_LLM_TENSOR_DATA_LAYOUT = "{llm}.tensor_data_layout" +KEY_LLM_CONTEXT_LENGTH = "{llm}.context_length" +KEY_LLM_EMBEDDING_LENGTH = "{llm}.embedding_length" +KEY_LLM_BLOCK_COUNT = "{llm}.block_count" +KEY_LLM_FEED_FORWARD_LENGTH = "{llm}.feed_forward_length" +KEY_LLM_USE_PARALLEL_RESIDUAL = "{llm}.use_parallel_residual" +KEY_LLM_TENSOR_DATA_LAYOUT = "{llm}.tensor_data_layout" # attention -KEY_ATTENTION_HEAD_COUNT = "{llm}.attention.head_count" -KEY_ATTENTION_HEAD_COUNT_KV = "{llm}.attention.head_count_kv" -KEY_ATTENTION_MAX_ALIBI_BIAS = "{llm}.attention.max_alibi_bias" -KEY_ATTENTION_CLAMP_KQV = "{llm}.attention.clamp_kqv" -KEY_ATTENTION_LAYERNORM_EPS = "{llm}.attention.layer_norm_epsilon" -KEY_ATTENTION_LAYERNORM_RMS_EPS = "{llm}.attention.layer_norm_rms_epsilon" +KEY_ATTENTION_HEAD_COUNT = "{llm}.attention.head_count" +KEY_ATTENTION_HEAD_COUNT_KV = "{llm}.attention.head_count_kv" +KEY_ATTENTION_MAX_ALIBI_BIAS = "{llm}.attention.max_alibi_bias" +KEY_ATTENTION_CLAMP_KQV = "{llm}.attention.clamp_kqv" +KEY_ATTENTION_LAYERNORM_EPS = "{llm}.attention.layer_norm_epsilon" +KEY_ATTENTION_LAYERNORM_RMS_EPS = "{llm}.attention.layer_norm_rms_epsilon" # RoPE -KEY_ROPE_DIMENSION_COUNT = "{llm}.rope.dimension_count" -KEY_ROPE_SCALE = "{llm}.rope.scale" +KEY_ROPE_DIMENSION_COUNT = "{llm}.rope.dimension_count" +KEY_ROPE_SCALE = "{llm}.rope.scale" # tokenization -KEY_TOKENIZER_MODEL = "tokenizer.ggml.model" -KEY_TOKENIZER_LIST = "tokenizer.ggml.tokens" +KEY_TOKENIZER_MODEL = "tokenizer.ggml.model" +KEY_TOKENIZER_LIST = "tokenizer.ggml.tokens" KEY_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type" -KEY_TOKENIZER_SCORES = "tokenizer.ggml.scores" -KEY_TOKENIZER_MERGES = "tokenizer.ggml.merges" -KEY_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id" -KEY_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id" -KEY_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id" -KEY_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id" -KEY_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id" -KEY_TOKENIZER_HF_JSON = "tokenizer.huggingface.json" -KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world" +KEY_TOKENIZER_SCORES = "tokenizer.ggml.scores" +KEY_TOKENIZER_MERGES = "tokenizer.ggml.merges" +KEY_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id" +KEY_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id" +KEY_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id" +KEY_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id" +KEY_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id" +KEY_TOKENIZER_HF_JSON = "tokenizer.huggingface.json" +KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world" # # recommended mapping of model tensor names for storage in gguf # -def get_tensor_name_map(n_blocks : int): + +def get_tensor_name_map(n_blocks: int): tensor_map = {} # Token embeddings mapped_to = "token_embd" - tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox - tensor_map["transformer.wte"] = mapped_to # gpt2 mpt - tensor_map["transformer.word_embeddings"] = mapped_to # falcon - tensor_map["model.embed_tokens"] = mapped_to # llama-hf - tensor_map["tok_embeddings"] = mapped_to # llama-pth + tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox + tensor_map["transformer.wte"] = mapped_to # gpt2 mpt + tensor_map["transformer.word_embeddings"] = mapped_to # falcon + tensor_map["model.embed_tokens"] = mapped_to # llama-hf + tensor_map["tok_embeddings"] = mapped_to # llama-pth # Position embeddings mapped_to = "pos_embd" - tensor_map["transformer.wpe"] = mapped_to # gpt2 + tensor_map["transformer.wpe"] = mapped_to # gpt2 # Output norm mapped_to = "output_norm" - tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox - tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon - tensor_map["transformer.norm_f"] = mapped_to # mpt - tensor_map["model.norm"] = mapped_to # llama-hf - tensor_map["norm"] = mapped_to # llama-pth + tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox + tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon + tensor_map["transformer.norm_f"] = mapped_to # mpt + tensor_map["model.norm"] = mapped_to # llama-hf + tensor_map["norm"] = mapped_to # llama-pth # Output mapped_to = "output" - tensor_map["embed_out"] = mapped_to # gptneox - tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf - tensor_map["output"] = mapped_to # llama-pth + tensor_map["embed_out"] = mapped_to # gptneox + tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf + tensor_map["output"] = mapped_to # llama-pth # Attention and fee-forward layer blocks - for i in range(0,n_blocks): + for i in range(0, n_blocks): # Attention norm mapped_to = "blk."+str(i)+".attn_norm" - tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt - tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b - tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b - tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth + tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b + tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b + tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth # Attention norm 2 mapped_to = "blk."+str(i)+".attn_norm_2" - tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b + tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b # Attention query-key-value mapped_to = "blk."+str(i)+".attn_qkv" - tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt - tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon + tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon # Attention query mapped_to = "blk."+str(i)+".attn_q" - tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth + tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth # Attention key mapped_to = "blk."+str(i)+".attn_k" - tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth + tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth # Attention value mapped_to = "blk."+str(i)+".attn_v" - tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth + tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth # Attention output mapped_to = "blk."+str(i)+".attn_output" - tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt - tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon - tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth + tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth # Feed-forward norm mapped_to = "blk."+str(i)+".ffn_norm" - tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt - tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth + tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt + tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth # Feed-forward up mapped_to = "blk."+str(i)+".ffn_up" - tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt - tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon - tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth + tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth # Feed-forward gate mapped_to = "blk."+str(i)+".ffn_gate" - tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth + tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth # Feed-forward down mapped_to = "blk."+str(i)+".ffn_down" - tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt - tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon - tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth + tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth return tensor_map @@ -168,22 +165,23 @@ def get_tensor_name_map(n_blocks : int): # implementation # + class GGMLQuantizationType(IntEnum): F32 = 0 F16 = 1 class GGUFValueType(IntEnum): - UINT8 = 0 - INT8 = 1 - UINT16 = 2 - INT16 = 3 - UINT32 = 4 - INT32 = 5 + UINT8 = 0 + INT8 = 1 + UINT16 = 2 + INT16 = 3 + UINT32 = 4 + INT32 = 5 FLOAT32 = 6 - BOOL = 7 - STRING = 8 - ARRAY = 9 + BOOL = 7 + STRING = 8 + ARRAY = 9 @staticmethod def get_type(val): @@ -203,8 +201,9 @@ class GGUFValueType(IntEnum): class GGUFWriter: - def __init__(self, fout: IO): - self.fout = fout + def __init__(self, path: str, architecture: str): + self.fout = open(path, "wb") + self.arch = architecture self.offset_tensor = 0 self.data_alignment = GGUF_DEFAULT_ALIGNMENT self.kv_data = b"" @@ -228,11 +227,6 @@ class GGUFWriter: self.fout.write(self.ti_data) self.flush() - @classmethod - def open(cls, path: str) -> "GGUFWriter": - f = open(path, "wb") - return cls(f) - def add_key(self, key: str): self.add_val(key, GGUFValueType.STRING, add_vtype=False) @@ -269,7 +263,8 @@ class GGUFWriter: self.add_val(val, GGUFValueType.BOOL) def add_string(self, key: str, val: str): - if len(val) == 0: return + if len(val) == 0: + return self.add_key(key) self.add_val(val, GGUFValueType.STRING) @@ -323,6 +318,8 @@ class GGUFWriter: return ((x + n - 1) // n) * n def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int): + assert tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now" + encoded_name = name.encode("utf8") self.ti_data += struct.pack("