mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
gguf : start write tensor info
This commit is contained in:
parent
8332d26123
commit
af1c9966c8
@ -1,5 +1,6 @@
|
||||
GGUF_MAGIC = 0x47475546
|
||||
GGUF_VERSION = 1
|
||||
GGUF_DEFAULT_ALIGNMENT = 32
|
||||
|
||||
# general
|
||||
KEY_GENERAL_ARCHITECTURE = "general.architecture"
|
||||
|
31
gguf.py
31
gguf.py
@ -7,7 +7,7 @@
|
||||
|
||||
import struct
|
||||
from enum import IntEnum
|
||||
from typing import List, Any
|
||||
from typing import List, Any, Sequence
|
||||
import constants
|
||||
|
||||
|
||||
@ -57,14 +57,15 @@ class GGUFValueType(IntEnum):
|
||||
|
||||
|
||||
class GGUFWriter:
|
||||
def __init__(self, buffered_writer):
|
||||
self.buffered_writer = buffered_writer
|
||||
def __init__(self, fout):
|
||||
self.fout = fout
|
||||
self.offset_tensor = 0
|
||||
|
||||
def write_header(self, tensor_count: int, metadata_kv_count: int):
|
||||
self.buffered_writer.write(struct.pack("<I", constants.GGUF_MAGIC))
|
||||
self.buffered_writer.write(struct.pack("<I", constants.GGUF_VERSION))
|
||||
self.buffered_writer.write(struct.pack("<I", tensor_count))
|
||||
self.buffered_writer.write(struct.pack("<I", metadata_kv_count))
|
||||
self.fout.write(struct.pack("<I", constants.GGUF_MAGIC))
|
||||
self.fout.write(struct.pack("<I", constants.GGUF_VERSION))
|
||||
self.fout.write(struct.pack("<I", tensor_count))
|
||||
self.fout.write(struct.pack("<I", metadata_kv_count))
|
||||
|
||||
@classmethod
|
||||
def open(cls, path: str) -> "GGUFWriter":
|
||||
@ -150,11 +151,23 @@ class GGUFWriter:
|
||||
else:
|
||||
raise ValueError("Invalid GGUF metadata value type")
|
||||
|
||||
def write_tensor_info(self, name: str, shape: Sequence[int], dtype: GGMLQuantizationType):
|
||||
self.write_value(name, GGUFValueType.STRING)
|
||||
n_dims = len(shape)
|
||||
self.write_value(n_dims, GGUFValueType.INT32)
|
||||
for i in range(n_dims):
|
||||
self.write_value(shape[n_dims - 1 - i], GGUFValueType.INT32)
|
||||
|
||||
self.fout.write(struct.pack("<Q", self.offset_tensor))
|
||||
|
||||
# TODO: update offset with alignment
|
||||
# probably we need a dict as a class attribute to hold tensor data while writing
|
||||
|
||||
def flush(self):
|
||||
self.buffered_writer.flush()
|
||||
self.fout.flush()
|
||||
|
||||
def close(self):
|
||||
self.buffered_writer.close()
|
||||
self.fout.close()
|
||||
|
||||
def write_architecture(self, architecture: str):
|
||||
self.write_string(constants.KEY_GENERAL_ARCHITECTURE,
|
||||
|
Loading…
Reference in New Issue
Block a user