gguf : start write tensor info

This commit is contained in:
M. Yusuf Sarıgöz 2023-07-27 10:32:31 +03:00
parent 8332d26123
commit af1c9966c8
2 changed files with 23 additions and 9 deletions

View File

@ -1,5 +1,6 @@
GGUF_MAGIC = 0x47475546 GGUF_MAGIC = 0x47475546
GGUF_VERSION = 1 GGUF_VERSION = 1
GGUF_DEFAULT_ALIGNMENT = 32
# general # general
KEY_GENERAL_ARCHITECTURE = "general.architecture" KEY_GENERAL_ARCHITECTURE = "general.architecture"

31
gguf.py
View File

@ -7,7 +7,7 @@
import struct import struct
from enum import IntEnum from enum import IntEnum
from typing import List, Any from typing import List, Any, Sequence
import constants import constants
@ -57,14 +57,15 @@ class GGUFValueType(IntEnum):
class GGUFWriter: class GGUFWriter:
def __init__(self, buffered_writer): def __init__(self, fout):
self.buffered_writer = buffered_writer self.fout = fout
self.offset_tensor = 0
def write_header(self, tensor_count: int, metadata_kv_count: int): def write_header(self, tensor_count: int, metadata_kv_count: int):
self.buffered_writer.write(struct.pack("<I", constants.GGUF_MAGIC)) self.fout.write(struct.pack("<I", constants.GGUF_MAGIC))
self.buffered_writer.write(struct.pack("<I", constants.GGUF_VERSION)) self.fout.write(struct.pack("<I", constants.GGUF_VERSION))
self.buffered_writer.write(struct.pack("<I", tensor_count)) self.fout.write(struct.pack("<I", tensor_count))
self.buffered_writer.write(struct.pack("<I", metadata_kv_count)) self.fout.write(struct.pack("<I", metadata_kv_count))
@classmethod @classmethod
def open(cls, path: str) -> "GGUFWriter": def open(cls, path: str) -> "GGUFWriter":
@ -150,11 +151,23 @@ class GGUFWriter:
else: else:
raise ValueError("Invalid GGUF metadata value type") raise ValueError("Invalid GGUF metadata value type")
def write_tensor_info(self, name: str, shape: Sequence[int], dtype: GGMLQuantizationType):
    """Write one tensor-info record for the tensor called ``name``.

    Emits, in order: the name, the number of dimensions, every dimension
    of ``shape`` from last to first, the tensor type, and the current
    value of ``self.offset_tensor``.

    Args:
        name: tensor name, written through ``write_value`` as a STRING.
        shape: tensor dimensions; serialized in reverse order.
        dtype: tensor type, packed as a little-endian unsigned 32-bit int.
    """
    self.write_value(name, GGUFValueType.STRING)
    self.write_value(len(shape), GGUFValueType.INT32)
    for dim in reversed(shape):
        self.write_value(dim, GGUFValueType.INT32)
    # The GGUF tensor-info layout places the type (uint32) between the
    # dimensions and the data offset, so it must be written here.
    self.fout.write(struct.pack("<I", int(dtype)))
    self.fout.write(struct.pack("<Q", self.offset_tensor))
    # TODO: update offset with alignment
    # probably we need a dict as a class attribute to hold tensor data while writing
def flush(self): def flush(self):
self.buffered_writer.flush() self.fout.flush()
def close(self): def close(self):
self.buffered_writer.close() self.fout.close()
def write_architecture(self, architecture: str): def write_architecture(self, architecture: str):
self.write_string(constants.KEY_GENERAL_ARCHITECTURE, self.write_string(constants.KEY_GENERAL_ARCHITECTURE,