From af1c9966c8ef1d160bc83bb8df898f5897a19d34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Yusuf=20Sar=C4=B1g=C3=B6z?= Date: Thu, 27 Jul 2023 10:32:31 +0300 Subject: [PATCH] gguf : start write tensor info --- constants.py | 1 + gguf.py | 31 ++++++++++++++++++++++--------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/constants.py b/constants.py index 7c7456403..3f4e42b10 100644 --- a/constants.py +++ b/constants.py @@ -1,5 +1,6 @@ GGUF_MAGIC = 0x47475546 GGUF_VERSION = 1 +GGUF_DEFAULT_ALIGNMENT = 32 # general KEY_GENERAL_ARCHITECTURE = "general.architecture" diff --git a/gguf.py b/gguf.py index 991bbe2f3..ec1ee4751 100644 --- a/gguf.py +++ b/gguf.py @@ -7,7 +7,7 @@ import struct from enum import IntEnum -from typing import List, Any +from typing import List, Any, Sequence import constants @@ -57,14 +57,15 @@ class GGUFValueType(IntEnum): class GGUFWriter: - def __init__(self, buffered_writer): - self.buffered_writer = buffered_writer + def __init__(self, fout): + self.fout = fout + self.offset_tensor = 0 def write_header(self, tensor_count: int, metadata_kv_count: int): - self.buffered_writer.write(struct.pack(" "GGUFWriter": @@ -150,11 +151,23 @@ class GGUFWriter: else: raise ValueError("Invalid GGUF metadata value type") + def write_tensor_info(self, name: str, shape: Sequence[int], dtype: GGMLQuantizationType): + self.write_value(name, GGUFValueType.STRING) + n_dims = len(shape) + self.write_value(n_dims, GGUFValueType.INT32) + for i in range(n_dims): + self.write_value(shape[n_dims - 1 - i], GGUFValueType.INT32) + + self.fout.write(struct.pack("