diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index fbaed64da..ec7f4dd75 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -12,7 +12,7 @@ import sys from enum import IntEnum from pathlib import Path from hashlib import sha256 -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast import numpy as np import torch @@ -48,7 +48,6 @@ class Model: dir_model: Path ftype: int - fname_out: Path is_big_endian: bool endianess: gguf.GGUFEndian use_temp_file: bool @@ -56,20 +55,20 @@ class Model: part_names: list[str] is_safetensors: bool hparams: dict[str, Any] - gguf_writer: gguf.GGUFWriter block_count: int tensor_map: gguf.TensorNameMap tensor_names: set[str] | None + fname_out: Path + gguf_writer: gguf.GGUFWriter # subclasses should define this! model_arch: gguf.MODEL_ARCH - def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool): - if self.__class__ == Model: - raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated") + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool): + if type(self) is Model: + raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") self.dir_model = dir_model self.ftype = ftype - self.fname_out = fname_out self.is_big_endian = is_big_endian self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.use_temp_file = use_temp_file @@ -79,10 +78,23 @@ class Model: if not self.is_safetensors: self.part_names = Model.get_model_part_names(self.dir_model, ".bin") self.hparams = Model.load_hparams(self.dir_model) - self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) self.tensor_names = None + if self.ftype == gguf.LlamaFileType.GUESSED: + # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. + _, first_tensor = next(self.get_tensors()) + if first_tensor.dtype == torch.float16: + logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})") + self.ftype = gguf.LlamaFileType.MOSTLY_F16 + else: + logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})") + self.ftype = gguf.LlamaFileType.MOSTLY_BF16 + ftype_up: str = self.ftype.name.partition("_")[2].upper() + ftype_lw: str = ftype_up.lower() + # allow templating the file name with the output ftype, useful with the "auto" ftype + self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up) + self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) @classmethod def __init_subclass__(cls): @@ -142,14 +154,27 @@ class Model: raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}") def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: - name: str = gguf.TENSOR_NAMES[key] if key not in gguf.MODEL_TENSORS[self.model_arch]: raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}") + name: str = gguf.TENSOR_NAMES[key] if "{bid}" in name: assert bid is not None name = name.format(bid=bid) return name + suffix + def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool: + if key not in gguf.MODEL_TENSORS[self.model_arch]: + return False + key_name: str = gguf.TENSOR_NAMES[key] + if "{bid}" in key_name: + if bid is None: + return False + key_name = key_name.format(bid=bid) + else: + if bid is not None: + return False + return name == (key_name + suffix) + def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) if new_name is None: @@ -215,6 +240,23 @@ class Model: return False def write_tensors(self): + # same as ggml_compute_fp32_to_bf16 in ggml-impl.h + def np_fp32_to_bf16(n: np.ndarray): + # force nan to quiet + n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n) + # flush subnormals to zero + n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n) + # round to nearest even + n = (n + (0x7fff + ((n >> 16) & 1))) >> 16 + return n.astype(np.int16) + + # Doing this row-wise is much, much faster than element-wise, hence the signature + v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)") + if self.lazy: + # TODO: find a way to implicitly wrap np.vectorize functions + # NOTE: the type is changed to reflect otypes passed to np.vectorize above + v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16) + max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") for name, data_torch in self.get_tensors(): @@ -239,35 +281,60 @@ class Model: data: np.ndarray = data # type hint n_dims = len(data.shape) data_dtype = data.dtype - - # if f32 desired, convert any float16 to float32 - if self.ftype == 0 and data_dtype == np.float16: - data = data.astype(np.float32) + data_qtype: gguf.GGMLQuantizationType | None = None # when both are True, f32 should win extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors - extra_f32 = extra_f32 or n_dims == 1 or new_name.endswith("_norm.weight") + # Conditions should closely match those in llama_model_quantize_internal in llama.cpp + extra_f32 = any(cond for cond in ( + extra_f32, + n_dims == 1, + new_name.endswith("_norm.weight"), + )) + + # Some tensor types are always in float32 + extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in ( + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.POS_EMBD, + gguf.MODEL_TENSOR.TOKEN_TYPES, + )) # if f16 desired, convert any float32 2-dim weight tensors to float16 - extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2) + extra_f16 = any(cond for cond in ( + extra_f16, + (name.endswith(".weight") and n_dims >= 2), + )) - # when both extra_f32 and extra_f16 are False, convert to float32 by default - if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16): - data = data.astype(np.float32) + if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32: + if self.ftype == gguf.LlamaFileType.MOSTLY_F16: + if data_dtype != np.float16: + data = data.astype(np.float16) + data_qtype = gguf.GGMLQuantizationType.F16 - if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32: - data = data.astype(np.float16) + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + if data_dtype != np.float32: + data = data.astype(np.float32) + data = v_fp32_to_bf16(data.view(np.int32)) + assert data.dtype == np.int16 + data_qtype = gguf.GGMLQuantizationType.BF16 + + else: # by default, convert to float32 + if data_dtype != np.float32: + data = data.astype(np.float32) + data_qtype = gguf.GGMLQuantizationType.F32 + + assert data_qtype is not None # reverse shape to make it similar to the internal ggml dimension order shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}" # n_dims is implicit in the shape - logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}") + logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") - self.gguf_writer.add_tensor(new_name, data) + self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype) def write(self): self.write_tensors() @@ -2044,12 +2111,6 @@ class BertModel(Model): return [(self.map_tensor_name(name), data_torch)] - def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: - del new_name, bid, n_dims # unused - - # not used with get_rows, must be F32 - return name == "embeddings.token_type_embeddings.weight" - @Model.register("NomicBertModel") class NomicBertModel(BertModel): @@ -2339,92 +2400,40 @@ class JinaBertV2Model(BertModel): # tree of lazy tensors -class LazyTorchTensor: - _meta: Tensor - _data: Tensor | None - _args: tuple - _func: Callable[[tuple], Tensor] | None - - def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None): - self._meta = meta - self._data = data - self._args = args - self._func = func - - @staticmethod - def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any: - # TODO: dict and set - if isinstance(o, (list, tuple)): - L = [] - for item in o: - L.append(LazyTorchTensor._recurse_apply(item, fn)) - if isinstance(o, tuple): - L = tuple(L) - return L - elif isinstance(o, LazyTorchTensor): - return fn(o) - else: - return o - - def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]: - def wrapped_fn(*args, **kwargs): - if kwargs is None: - kwargs = {} - args = ((self,) if use_self else ()) + args - - meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta) - - return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs)) - return wrapped_fn - - def __getattr__(self, __name: str) -> Any: - meta_attr = getattr(self._meta, __name) - if callable(meta_attr): - return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True) - elif isinstance(meta_attr, torch.Tensor): - # for things like self.T - return self._wrap_fn(lambda s: getattr(s, __name))(self) - else: - return meta_attr +class LazyTorchTensor(gguf.LazyBase): + _tensor_type = torch.Tensor + # to keep the type-checker happy + dtype: torch.dtype + shape: torch.Size + # only used when converting a torch.Tensor to a np.ndarray _dtype_map: dict[torch.dtype, type] = { torch.float16: np.float16, torch.float32: np.float32, } - def numpy(self) -> gguf.LazyTensor: + def numpy(self) -> gguf.LazyNumpyTensor: dtype = self._dtype_map[self.dtype] - return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape) + return gguf.LazyNumpyTensor( + meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)), + lazy=self._lazy, + args=(self,), + func=(lambda s: s[0].numpy()) + ) - @overload - @staticmethod - def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ... - - @overload - @staticmethod - def to_eager(t: tuple) -> tuple: ... - - @staticmethod - def to_eager(t: Any) -> Any: - def simple_to_eager(_t: LazyTorchTensor) -> Tensor: - # wake up the lazy tensor - if _t._data is None and _t._func is not None: - # recurse into its arguments - _t._args = LazyTorchTensor.to_eager(_t._args) - _t._data = _t._func(_t._args) - if _t._data is not None: - return _t._data - else: - raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}") - - # recurse into lists and/or tuples, keeping their structure - return LazyTorchTensor._recurse_apply(t, simple_to_eager) - - @staticmethod - def from_eager(t: Tensor) -> Tensor: - if (t.__class__ == LazyTorchTensor): + @classmethod + def eager_to_meta(cls, t: Tensor) -> Tensor: + if t.is_meta: return t - return LazyTorchTensor(meta=t.detach().to("meta"), data=t) # type: ignore + return t.detach().to("meta") + + @classmethod + def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor: + m = m.detach() + if not m.is_meta: + m = m.to("meta") + m.dtype = dtype + return m @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -2435,28 +2444,8 @@ class LazyTorchTensor: if func is torch.Tensor.numpy: return args[0].numpy() - if func is torch.equal: - eager_args = LazyTorchTensor.to_eager(args) - return func(*eager_args, **kwargs) - return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs) - - # special methods bypass __getattr__, so they need to be added manually - # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup - # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used - # as self._meta is currently used), because then the following - # operations would by default not be wrapped, and so not propagated - # when the tensor is made eager. - # It's better to get non-silent errors for not-yet-supported operators. - # TODO: add more when needed to avoid clutter, or find a more concise way - def __neg__(self, *args): # mamba - return self._wrap_fn(torch.Tensor.__neg__)(self, *args) - - def __add__(self, *args): # gemma - return self._wrap_fn(torch.Tensor.__add__)(self, *args) - - def __getitem__(self, *args): # bloom falcon refact internlm2 - return self._wrap_fn(torch.Tensor.__getitem__)(self, *args) + return LazyTorchTensor._wrap_fn(func)(*args, **kwargs) def parse_args() -> argparse.Namespace: @@ -2472,11 +2461,11 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument( "--outfile", type=Path, - help="path to write to; default: based on input", + help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16"], default="f16", - help="output format - use f32 for float32, f16 for float16", + "--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( "--bigendian", action="store_true", @@ -2530,16 +2519,18 @@ def main() -> None: logger.error(f'Error: {args.model} is not a directory') sys.exit(1) - ftype_map = { - "f32": gguf.GGMLQuantizationType.F32, - "f16": gguf.GGMLQuantizationType.F16, + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "auto": gguf.LlamaFileType.GUESSED, } if args.outfile is not None: fname_out = args.outfile else: # output in the same directory as the model by default - fname_out = dir_model / f'ggml-model-{args.outtype}.gguf' + fname_out = dir_model / 'ggml-model-{ftype}.gguf' logger.info(f"Loading model: {dir_model.name}") @@ -2555,14 +2546,16 @@ def main() -> None: logger.info("Set model tokenizer") model_instance.set_vocab() + model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) + if args.vocab_only: - logger.info(f"Exporting model vocab to '{fname_out}'") + logger.info(f"Exporting model vocab to '{model_instance.fname_out}'") model_instance.write_vocab() else: - logger.info(f"Exporting model to '{fname_out}'") + logger.info(f"Exporting model to '{model_instance.fname_out}'") model_instance.write() - logger.info(f"Model successfully exported to '{fname_out}'") + logger.info(f"Model successfully exported to '{model_instance.fname_out}'") if __name__ == '__main__': diff --git a/gguf-py/gguf/__init__.py b/gguf-py/gguf/__init__.py index 110ab342c..e5d5806c8 100644 --- a/gguf-py/gguf/__init__.py +++ b/gguf-py/gguf/__init__.py @@ -1,4 +1,5 @@ from .constants import * +from .lazy import * from .gguf_reader import * from .gguf_writer import * from .tensor_mapping import * diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a4fbfc5e0..978fcada3 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -10,6 +10,7 @@ from typing import Any GGUF_MAGIC = 0x46554747 # "GGUF" GGUF_VERSION = 3 GGUF_DEFAULT_ALIGNMENT = 32 +GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h # # metadata keys @@ -838,6 +839,49 @@ class GGMLQuantizationType(IntEnum): BF16 = 30 +# TODO: add GGMLFileType from ggml_ftype in ggml.h + + +# from llama_ftype in llama.h +# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE. +class LlamaFileType(IntEnum): + ALL_F32 = 0 + MOSTLY_F16 = 1 # except 1d tensors + MOSTLY_Q4_0 = 2 # except 1d tensors + MOSTLY_Q4_1 = 3 # except 1d tensors + MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16 + # MOSTLY_Q4_2 = 5 # support has been removed + # MOSTLY_Q4_3 = 6 # support has been removed + MOSTLY_Q8_0 = 7 # except 1d tensors + MOSTLY_Q5_0 = 8 # except 1d tensors + MOSTLY_Q5_1 = 9 # except 1d tensors + MOSTLY_Q2_K = 10 # except 1d tensors + MOSTLY_Q3_K_S = 11 # except 1d tensors + MOSTLY_Q3_K_M = 12 # except 1d tensors + MOSTLY_Q3_K_L = 13 # except 1d tensors + MOSTLY_Q4_K_S = 14 # except 1d tensors + MOSTLY_Q4_K_M = 15 # except 1d tensors + MOSTLY_Q5_K_S = 16 # except 1d tensors + MOSTLY_Q5_K_M = 17 # except 1d tensors + MOSTLY_Q6_K = 18 # except 1d tensors + MOSTLY_IQ2_XXS = 19 # except 1d tensors + MOSTLY_IQ2_XS = 20 # except 1d tensors + MOSTLY_Q2_K_S = 21 # except 1d tensors + MOSTLY_IQ3_XS = 22 # except 1d tensors + MOSTLY_IQ3_XXS = 23 # except 1d tensors + MOSTLY_IQ1_S = 24 # except 1d tensors + MOSTLY_IQ4_NL = 25 # except 1d tensors + MOSTLY_IQ3_S = 26 # except 1d tensors + MOSTLY_IQ3_M = 27 # except 1d tensors + MOSTLY_IQ2_S = 28 # except 1d tensors + MOSTLY_IQ2_M = 29 # except 1d tensors + MOSTLY_IQ4_XS = 30 # except 1d tensors + MOSTLY_IQ1_M = 31 # except 1d tensors + MOSTLY_BF16 = 32 # except 1d tensors + + GUESSED = 1024 # not specified in the model file + + class GGUFEndian(IntEnum): LITTLE = 0 BIG = 1 diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 8dcf9330b..96574358d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -7,7 +7,7 @@ import struct import tempfile from enum import Enum, auto from io import BufferedWriter -from typing import IO, Any, Callable, Sequence, Mapping +from typing import IO, Any, Sequence, Mapping from string import ascii_letters, digits import numpy as np @@ -28,47 +28,6 @@ from .constants import ( logger = logging.getLogger(__name__) -class LazyTensor: - data: Callable[[], np.ndarray[Any, Any]] - # to avoid too deep recursion - functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]] - dtype: np.dtype[Any] - shape: tuple[int, ...] - - def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]): - self.data = data - self.functions = [] - self.dtype = np.dtype(dtype) - self.shape = shape - - def astype(self, dtype: type, **kwargs) -> LazyTensor: - self.functions.append(lambda n: n.astype(dtype, **kwargs)) - self.dtype = np.dtype(dtype) - return self - - @property - def nbytes(self) -> int: - size = 1 - for n in self.shape: - size *= n - return size * self.dtype.itemsize - - def tofile(self, *args, **kwargs) -> None: - data = self.data() - for f in self.functions: - data = f(data) - assert data.shape == self.shape - assert data.dtype == self.dtype - assert data.nbytes == self.nbytes - self.functions = [] - self.data = lambda: data - data.tofile(*args, **kwargs) - - def byteswap(self, *args, **kwargs) -> LazyTensor: - self.functions.append(lambda n: n.byteswap(*args, **kwargs)) - return self - - class WriterState(Enum): EMPTY = auto() HEADER = auto() @@ -79,7 +38,7 @@ class WriterState(Enum): class GGUFWriter: fout: BufferedWriter temp_file: tempfile.SpooledTemporaryFile[bytes] | None - tensors: list[np.ndarray[Any, Any] | LazyTensor] + tensors: list[np.ndarray[Any, Any]] _simple_value_packing = { GGUFValueType.UINT8: "B", GGUFValueType.INT8: "b", @@ -278,7 +237,7 @@ class GGUFWriter: self.ti_data_count += 1 def add_tensor( - self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None, + self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None, ) -> None: if self.endianess == GGUFEndian.BIG: @@ -303,7 +262,7 @@ class GGUFWriter: if pad != 0: fp.write(bytes([0] * pad)) - def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None: + def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: if self.state is not WriterState.TI_DATA: raise ValueError(f'Expected output file to contain tensor info, got {self.state}') @@ -391,7 +350,7 @@ class GGUFWriter: def add_name(self, name: str) -> None: self.add_string(Keys.General.NAME, name) - def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None: + def add_quantization_version(self, quantization_version: int) -> None: self.add_uint32( Keys.General.QUANTIZATION_VERSION, quantization_version) diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py new file mode 100644 index 000000000..650bea11c --- /dev/null +++ b/gguf-py/gguf/lazy.py @@ -0,0 +1,225 @@ +from __future__ import annotations +from abc import ABC, ABCMeta, abstractmethod + +import logging +from typing import Any, Callable +from collections import deque + +import numpy as np +from numpy.typing import DTypeLike + + +logger = logging.getLogger(__name__) + + +class LazyMeta(ABCMeta): + + def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs): + def __getattr__(self, __name: str) -> Any: + meta_attr = getattr(self._meta, __name) + if callable(meta_attr): + return type(self)._wrap_fn( + (lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)), + use_self=self, + ) + elif isinstance(meta_attr, self._tensor_type): + # e.g. self.T with torch.Tensor should still be wrapped + return type(self)._wrap_fn(lambda s: getattr(s, __name))(self) + else: + # no need to wrap non-tensor properties, + # and they likely don't depend on the actual contents of the tensor + return meta_attr + + namespace["__getattr__"] = __getattr__ + + # need to make a builder for the wrapped wrapper to copy the name, + # or else it fails with very cryptic error messages, + # because somehow the same string would end up in every closures + def mk_wrap(op_name: str, *, meta_noop: bool = False): + # need to wrap the wrapper to get self + def wrapped_special_op(self, *args, **kwargs): + return type(self)._wrap_fn( + getattr(type(self)._tensor_type, op_name), + meta_noop=meta_noop, + )(self, *args, **kwargs) + return wrapped_special_op + + # special methods bypass __getattr__, so they need to be added manually + # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup + # NOTE: doing this from a metaclass is very convenient + # TODO: make this even more comprehensive + for binary_op in ( + "lt", "le", "eq", "ne", "ge", "gt", "not" + "abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul", + "neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor", + "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor", + "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor", + ): + attr_name = f"__{binary_op}__" + # the result of these operators usually has the same shape and dtype as the input, + # so evaluation on the meta tensor can be skipped. + namespace[attr_name] = mk_wrap(attr_name, meta_noop=True) + + for special_op in ( + "getitem", "setitem", "len", + ): + attr_name = f"__{special_op}__" + namespace[attr_name] = mk_wrap(attr_name, meta_noop=False) + + return super().__new__(cls, name, bases, namespace, **kwargs) + + +# Tree of lazy tensors +class LazyBase(ABC, metaclass=LazyMeta): + _tensor_type: type + _meta: Any + _data: Any | None + _lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager + _args: tuple + _func: Callable[[tuple], Any] | None + + def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None): + super().__init__() + self._meta = meta + self._data = data + self._lazy = lazy if lazy is not None else deque() + self._args = args + self._func = func + assert self._func is not None or self._data is not None + if self._data is None: + self._lazy.append(self) + + def __init_subclass__(cls) -> None: + if "_tensor_type" not in cls.__dict__: + raise TypeError(f"property '_tensor_type' must be defined for {cls!r}") + return super().__init_subclass__() + + @staticmethod + def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any: + # TODO: dict and set + if isinstance(o, (list, tuple)): + L = [] + for item in o: + L.append(LazyBase._recurse_apply(item, fn)) + if isinstance(o, tuple): + L = tuple(L) + return L + elif isinstance(o, LazyBase): + return fn(o) + else: + return o + + @classmethod + def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]: + def wrapped_fn(*args, **kwargs): + if kwargs is None: + kwargs = {} + args = ((use_self,) if use_self is not None else ()) + args + + meta_args = LazyBase._recurse_apply(args, lambda t: t._meta) + + if isinstance(meta_noop, bool) and not meta_noop: + try: + res = fn(*meta_args, **kwargs) + except NotImplementedError: + # running some operations on PyTorch's Meta tensors can cause this exception + res = None + else: + # some operators don't need to actually run on the meta tensors + assert len(args) > 0 + res = args[0] + assert isinstance(res, cls) + res = res._meta + # allow operations to override the dtype + if meta_noop is not True: + res = cls.meta_with_dtype(res, meta_noop) + + if isinstance(res, cls._tensor_type): + def collect_replace(t: LazyBase): + if collect_replace.shared_lazy is None: + collect_replace.shared_lazy = t._lazy + else: + collect_replace.shared_lazy.extend(t._lazy) + t._lazy = collect_replace.shared_lazy + + # emulating a static variable + collect_replace.shared_lazy = None + + LazyBase._recurse_apply(args, collect_replace) + + shared_lazy = collect_replace.shared_lazy + + return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs)) + else: + del res # not needed + # non-tensor return likely relies on the contents of the args + # (e.g. the result of torch.equal) + eager_args = cls.to_eager(args) + return fn(*eager_args, **kwargs) + return wrapped_fn + + @classmethod + def to_eager(cls, t: Any) -> Any: + def simple_to_eager(_t: LazyBase) -> Any: + def already_eager_to_eager(_t: LazyBase) -> Any: + assert _t._data is not None + return _t._data + + while _t._data is None: + lt = _t._lazy.popleft() + if lt._data is not None: + raise ValueError(f"{lt} did not belong in the lazy queue") + assert lt._func is not None + lt._args = cls._recurse_apply(lt._args, already_eager_to_eager) + lt._data = lt._func(lt._args) + # sanity check + assert lt._data.dtype == lt._meta.dtype + assert lt._data.shape == lt._meta.shape + + return _t._data + + # recurse into lists and/or tuples, keeping their structure + return cls._recurse_apply(t, simple_to_eager) + + @classmethod + def eager_to_meta(cls, t: Any) -> Any: + return cls.meta_with_dtype(t, t.dtype) + + # must be overridden, meta tensor init is backend-specific + @classmethod + @abstractmethod + def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass + + @classmethod + def from_eager(cls, t: Any) -> Any: + if type(t) is cls: + # already eager + return t + elif isinstance(t, cls._tensor_type): + return cls(meta=cls.eager_to_meta(t), data=t) + else: + return TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}") + + +class LazyNumpyTensor(LazyBase): + _tensor_type = np.ndarray + + @classmethod + def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]: + # The initial idea was to use np.nan as the fill value, + # but non-float types like np.int16 can't use that. + # So zero it is. + cheat = np.zeros(1, dtype) + return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape)) + + def astype(self, dtype, *args, **kwargs): + meta = type(self).meta_with_dtype(self._meta, dtype) + full_args = (self, dtype,) + args + # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere. + return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs))) + + def tofile(self, *args, **kwargs): + eager = LazyNumpyTensor.to_eager(self) + return eager.tofile(*args, **kwargs) + + # TODO: __array_function__