gguf : start write tensor info

This commit is contained in:
M. Yusuf Sarıgöz 2023-07-27 10:32:31 +03:00
parent 8332d26123
commit af1c9966c8
2 changed files with 23 additions and 9 deletions

View File

@ -1,5 +1,6 @@
GGUF_MAGIC = 0x47475546
GGUF_VERSION = 1
GGUF_DEFAULT_ALIGNMENT = 32
# general
KEY_GENERAL_ARCHITECTURE = "general.architecture"

31
gguf.py
View File

@ -7,7 +7,7 @@
import struct
from enum import IntEnum
from typing import List, Any
from typing import List, Any, Sequence
import constants
@ -57,14 +57,15 @@ class GGUFValueType(IntEnum):
class GGUFWriter:
def __init__(self, buffered_writer):
self.buffered_writer = buffered_writer
def __init__(self, fout):
self.fout = fout
self.offset_tensor = 0
def write_header(self, tensor_count: int, metadata_kv_count: int):
self.buffered_writer.write(struct.pack("<I", constants.GGUF_MAGIC))
self.buffered_writer.write(struct.pack("<I", constants.GGUF_VERSION))
self.buffered_writer.write(struct.pack("<I", tensor_count))
self.buffered_writer.write(struct.pack("<I", metadata_kv_count))
self.fout.write(struct.pack("<I", constants.GGUF_MAGIC))
self.fout.write(struct.pack("<I", constants.GGUF_VERSION))
self.fout.write(struct.pack("<I", tensor_count))
self.fout.write(struct.pack("<I", metadata_kv_count))
@classmethod
def open(cls, path: str) -> "GGUFWriter":
@ -150,11 +151,23 @@ class GGUFWriter:
else:
raise ValueError("Invalid GGUF metadata value type")
def write_tensor_info(self, name: str, shape: Sequence[int], dtype: GGMLQuantizationType):
self.write_value(name, GGUFValueType.STRING)
n_dims = len(shape)
self.write_value(n_dims, GGUFValueType.INT32)
for i in range(n_dims):
self.write_value(shape[n_dims - 1 - i], GGUFValueType.INT32)
self.fout.write(struct.pack("<Q", self.offset_tensor))
# TODO: update offset with alignment
# probably we need a dict as a class attribute to hold tensor data while writing
def flush(self):
self.buffered_writer.flush()
self.fout.flush()
def close(self):
self.buffered_writer.close()
self.fout.close()
def write_architecture(self, architecture: str):
self.write_string(constants.KEY_GENERAL_ARCHITECTURE,