diff --git a/constants.py b/constants.py index 7c7456403..3a97460e5 100644 --- a/constants.py +++ b/constants.py @@ -1,32 +1,32 @@ -GGUF_MAGIC = 0x47475546 +GGUF_MAGIC = 0x47475546 GGUF_VERSION = 1 # general -KEY_GENERAL_ARCHITECTURE = "general.architecture" +KEY_GENERAL_ARCHITECTURE = "general.architecture" KEY_GENERAL_QUANTIZATION_VERSION = "general.quantization_version" -KEY_GENERAL_NAME = "general.name" -KEY_GENERAL_AUTHOR = "general.author" -KEY_GENERAL_URL = "general.url" -KEY_GENERAL_DESCRIPTION = "general.description" -KEY_GENERAL_FILE_TYPE = "general.file_type" -KEY_GENERAL_LICENSE = "general.license" -KEY_GENERAL_SOURCE_URL = "general.source.url" -KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository" +KEY_GENERAL_NAME = "general.name" +KEY_GENERAL_AUTHOR = "general.author" +KEY_GENERAL_URL = "general.url" +KEY_GENERAL_DESCRIPTION = "general.description" +KEY_GENERAL_FILE_TYPE = "general.file_type" +KEY_GENERAL_LICENSE = "general.license" +KEY_GENERAL_SOURCE_URL = "general.source.url" +KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository" # LLM -KEY_LLM_CONTEXT_LENGTH = "{llm}.context_length" -KEY_LLM_EMBEDDING_LENGTH = "{llm}.embedding_length" -KEY_LLM_LAYER_COUNT = "{llm}.layer_count" -KEY_LLM_FEED_FORWARD_LENGTH = "{llm}.feed_forward_length" -KEY_LLM_USE_PARALLEL_RESIDUAL = "{llm}.use_parallel_residual" -KEY_LLM_TENSOR_DATA_LAYOUT = "{llm}.tensor_data_layout" +KEY_LLM_CONTEXT_LENGTH = "{llm}.context_length" +KEY_LLM_EMBEDDING_LENGTH = "{llm}.embedding_length" +KEY_LLM_LAYER_COUNT = "{llm}.layer_count" +KEY_LLM_FEED_FORWARD_LENGTH = "{llm}.feed_forward_length" +KEY_LLM_USE_PARALLEL_RESIDUAL = "{llm}.use_parallel_residual" +KEY_LLM_TENSOR_DATA_LAYOUT = "{llm}.tensor_data_layout" # attention -KEY_ATTENTION_HEAD_COUNT = "{llm}.attention.head_count" -KEY_ATTENTION_HEAD_COUNT_KV = "{llm}.attention.head_count_kv" -KEY_ATTENTION_MAX_ALIBI_BIAS = "{llm}.attention.max_alibi_bias" -KEY_ATTENTION_CLAMP_KQV = "{llm}.attention.clamp_kqv" +KEY_ATTENTION_HEAD_COUNT = "{llm}.attention.head_count" +KEY_ATTENTION_HEAD_COUNT_KV = "{llm}.attention.head_count_kv" +KEY_ATTENTION_MAX_ALIBI_BIAS = "{llm}.attention.max_alibi_bias" +KEY_ATTENTION_CLAMP_KQV = "{llm}.attention.clamp_kqv" # RoPE -KEY_ROPE_DIMENSION_COUNT = "{llm}.rope.dimension_count" -KEY_ROPE_SCALE = "{llm}.rope.scale" +KEY_ROPE_DIMENSION_COUNT = "{llm}.rope.dimension_count" +KEY_ROPE_SCALE = "{llm}.rope.scale" diff --git a/gguf.py b/gguf.py index 991bbe2f3..764ae9a9d 100644 --- a/gguf.py +++ b/gguf.py @@ -6,14 +6,13 @@ """ import struct +import constants from enum import IntEnum from typing import List, Any -import constants - class GGMLQuantizationType(IntEnum): - F32 = 0 - F16 = 1 + F32 = 0 + F16 = 1 QR_0 = 2 Q4_1 = 3 # Q4_2 = 4 # support has been removed @@ -31,31 +30,30 @@ class GGMLQuantizationType(IntEnum): class GGUFValueType(IntEnum): - UINT8 = 0 - INT8 = 1 - UINT16 = 2 - INT16 = 3 - UINT32 = 4 - INT32 = 5 + UINT8 = 0 + INT8 = 1 + UINT16 = 2 + INT16 = 3 + UINT32 = 4 + INT32 = 5 FLOAT32 = 6 - BOOL = 7 - STRING = 8 - ARRAY = 9 + BOOL = 7 + STRING = 8 + ARRAY = 9 @staticmethod - def get_type(value): - if isinstance(value, str): + def get_type(val): + if isinstance(val, str): return GGUFValueType.STRING - elif isinstance(value, list): + elif isinstance(val, list): return GGUFValueType.ARRAY - elif isinstance(value, float): + elif isinstance(val, float): return GGUFValueType.FLOAT32 - elif isinstance(value, bool): + elif isinstance(val, bool): return GGUFValueType.BOOL else: return GGUFValueType.INT32 - class GGUFWriter: def __init__(self, buffered_writer): self.buffered_writer = buffered_writer @@ -72,81 +70,81 @@ class GGUFWriter: return cls(f) def write_key(self, key: str): - self.write_value(key, GGUFValueType.STRING) + self.write_val(key, GGUFValueType.STRING) - def write_uint8(self, key: str, value: int): + def write_uint8(self, key: str, val: int): self.write_key(key) - self.write_value(value, GGUFValueType.UINT8) + self.write_val(val, GGUFValueType.UINT8) - def write_int8(self, key: str, value: int): + def write_int8(self, key: str, val: int): self.write_key(key) - self.write_value(value, GGUFValueType.INT8) + self.write_val(val, GGUFValueType.INT8) - def write_uint16(self, key: str, value: int): + def write_uint16(self, key: str, val: int): self.write_key(key) - self.write_value(value, GGUFValueType.UINT16) + self.write_val(val, GGUFValueType.UINT16) - def write_int16(self, key: str, value: int): + def write_int16(self, key: str, val: int): self.write_key(key) - self.write_value(value, GGUFValueType.INT16) + self.write_val(val, GGUFValueType.INT16) - def write_uint32(self, key: str, value: int): + def write_uint32(self, key: str, val: int): self.write_key(key) - self.write(value, GGUFValueType.UINT32) + self.write_val(val, GGUFValueType.UINT32) - def write_int32(self, key: str, value: int): + def write_int32(self, key: str, val: int): self.write_key(key) - self.write_value(value, GGUFValueType.INT32) + self.write_val(val, GGUFValueType.INT32) - def write_float32(self, key: str, value: float): + def write_float32(self, key: str, val: float): self.write_key(key) - self.write_value(value, GGUFValueType.FLOAT32) + self.write_val(val, GGUFValueType.FLOAT32) - def write_bool(self, key: str, value: bool): + def write_bool(self, key: str, val: bool): self.write_key(key) - self.write_value(value, GGUFValueType.BOOL) + self.write_val(val, GGUFValueType.BOOL) - def write_string(self, key: str, value: str): + def write_string(self, key: str, val: str): self.write_key(key) - self.write_value(value, GGUFValueType.STRING) + self.write_val(val, GGUFValueType.STRING) - def write_array(self, key: str, value: list): - if not isinstance(value, list): + def write_array(self, key: str, val: list): + if not isinstance(val, list): raise ValueError("Value must be a list for array type") self.write_key(key) - self.write_value(value, GGUFValueType.ARRAY) + self.write_val(val, GGUFValueType.ARRAY) - def write_value(self: str, value: Any, value_type: GGUFValueType = None): - if value_type is None: - value_type = GGUFValueType.get_type(value) + def write_val(self: str, val: Any, vtype: GGUFValueType = None): + if vtype is None: + vtype = GGUFValueType.get_type(val) - self.buffered_writer.write(struct.pack("