gguf.py : use custom alignment if present

This commit is contained in:
klosax 2023-08-07 13:51:26 +02:00 committed by GitHub
parent db5618ad99
commit 4357e692ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -64,6 +64,7 @@ class GGUFWriter:
def __init__(self, fout: IO): def __init__(self, fout: IO):
self.fout = fout self.fout = fout
self.offset_tensor = 0 self.offset_tensor = 0
self.data_alignment = constants.GGUF_DEFAULT_ALIGNMENT
self.kv_data = b"" self.kv_data = b""
self.kv_data_count = 0 self.kv_data_count = 0
self.ti_data = b"" self.ti_data = b""
@ -191,16 +192,17 @@ class GGUFWriter:
dtype = GGMLQuantizationType.F32 if tensor.dtype == np.float32 else GGMLQuantizationType.F16 dtype = GGMLQuantizationType.F32 if tensor.dtype == np.float32 else GGMLQuantizationType.F16
self.ti_data += struct.pack("<I", dtype) self.ti_data += struct.pack("<I", dtype)
self.ti_data += struct.pack("<Q", self.offset_tensor) self.ti_data += struct.pack("<Q", self.offset_tensor)
self.offset_tensor += GGUFWriter.ggml_pad(tensor.nbytes, constants.GGUF_DEFAULT_ALIGNMENT) self.offset_tensor += GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment)
self.ti_data_count += 1 self.ti_data_count += 1
def write_tensor_to_file(self, tensor: np.ndarray): def write_tensor_to_file(self, tensor: np.ndarray):
pad = GGUFWriter.ggml_pad(self.fout.tell(), constants.GGUF_DEFAULT_ALIGNMENT) - self.fout.tell() pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell()
if pad != 0: if pad != 0:
self.fout.write(bytes([0] * pad)) self.fout.write(bytes([0] * pad))
tensor.tofile(self.fout) tensor.tofile(self.fout)
pad = GGUFWriter.ggml_pad(tensor.nbytes, constants.GGUF_DEFAULT_ALIGNMENT) - tensor.nbytes
pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
if pad != 0: if pad != 0:
self.fout.write(bytes([0] * pad)) self.fout.write(bytes([0] * pad))
@ -240,6 +242,7 @@ class GGUFWriter:
constants.KEY_GENERAL_QUANTIZATION_VERSION, quantization_version) constants.KEY_GENERAL_QUANTIZATION_VERSION, quantization_version)
def add_custom_alignment(self, alignment: int): def add_custom_alignment(self, alignment: int):
self.data_alignment = alignment
self.add_uint32(constants.KEY_GENERAL_ALIGNMENT, alignment) self.add_uint32(constants.KEY_GENERAL_ALIGNMENT, alignment)
def add_context_length(self, llm: str, length: int): def add_context_length(self, llm: str, length: int):