diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index d6e5dece0..cd875fa4a 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -240,23 +240,6 @@ class Model:
         return False
 
     def write_tensors(self):
-        # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
-        def np_fp32_to_bf16(n: np.ndarray):
-            # force nan to quiet
-            n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
-            # flush subnormals to zero
-            n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
-            # round to nearest even
-            n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
-            return n.astype(np.int16)
-
-        # Doing this row-wise is much, much faster than element-wise, hence the signature
-        v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
-        if self.lazy:
-            # TODO: find a way to implicitly wrap np.vectorize functions
-            # NOTE: the type is changed to reflect otypes passed to np.vectorize above
-            v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
-
         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
 
         for name, data_torch in self.get_tensors():
@@ -309,27 +292,31 @@ class Model:
                 ))
 
             if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
-                if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+                if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                    data = gguf.quantize_bf16(data)
+                    assert data.dtype == np.int16
+                    data_qtype = gguf.GGMLQuantizationType.BF16
+
+                elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
+                    data = gguf.quantize_q8_0(data)
+                    assert data.dtype == np.uint8
+                    data_qtype = gguf.GGMLQuantizationType.Q8_0
+
+                else:  # default to float16 for quantized tensors
                     if data_dtype != np.float16:
                         data = data.astype(np.float16)
                     data_qtype = gguf.GGMLQuantizationType.F16
 
-                elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
-                    if data_dtype != np.float32:
-                        data = data.astype(np.float32)
-                    data = v_fp32_to_bf16(data.view(np.int32))
-                    assert data.dtype == np.int16
-                    data_qtype = gguf.GGMLQuantizationType.BF16
-
-                else:  # by default, convert to float32
+            if data_qtype is None:  # by default, convert to float32
                 if data_dtype != np.float32:
                     data = data.astype(np.float32)
                 data_qtype = gguf.GGMLQuantizationType.F32
 
-            assert data_qtype is not None
-
+            block_size, type_size = gguf.GGML_QUANT_SIZES[data_qtype]
             # reverse shape to make it similar to the internal ggml dimension order
-            shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
+            shape_str = f"""{{{', '.join(str(n) for n in reversed(
+                (*data.shape[:-1], data.shape[-1] * data.dtype.itemsize // type_size * block_size))
+            )}}}"""
 
             # n_dims is implicit in the shape
             logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
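
For reference, the arithmetic behind the new shape_str expression, as a standalone sketch (the tensor size and the Q8_0 constants are illustrative, not taken from the patch):

    import numpy as np

    # GGML_QUANT_SIZES[Q8_0] == (32, 34): 32 weights per block, 34 bytes per block
    block_size, type_size = 32, 34

    # a hypothetical 4096 x 11008 tensor, already packed into Q8_0 byte rows
    data = np.empty((4096, 11008 // block_size * type_size), dtype=np.uint8)

    # turn the byte count of the last dimension back into an element count,
    # then reverse the dims to match ggml's internal ordering
    shape_str = f"""{{{', '.join(str(n) for n in reversed(
        (*data.shape[:-1], data.shape[-1] * data.dtype.itemsize // type_size * block_size))
    )}}}"""
    print(shape_str)  # {11008, 4096}

For unquantized dtypes the same expression is a no-op: F32 has block size 1 and type size 4, so itemsize // type_size * block_size leaves the element count unchanged.
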
self.hparams["rope_scaling"]: if self.hparams["rope_scaling"].get("type") == "linear": @@ -1215,6 +1204,7 @@ class StableLMModel(Model): self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True) self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"])) + self.gguf_writer.add_file_type(self.ftype) _q_norms: list[dict[str, Tensor]] | None = None _k_norms: list[dict[str, Tensor]] | None = None @@ -1591,6 +1581,7 @@ class QwenModel(Model): self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) @Model.register("Qwen2ForCausalLM") @@ -1828,6 +1819,7 @@ class PlamoModel(Model): self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + self.gguf_writer.add_file_type(self.ftype) def shuffle_attn_q_weight(self, data_torch): assert data_torch.size() == (5120, 5120) @@ -2007,6 +1999,7 @@ in chat mode so that the conversation can end normally.") self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) + self.gguf_writer.add_file_type(self.ftype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: num_heads = self.hparams["num_attention_heads"] @@ -2415,25 +2408,15 @@ class LazyTorchTensor(gguf.LazyBase): def numpy(self) -> gguf.LazyNumpyTensor: dtype = self._dtype_map[self.dtype] return gguf.LazyNumpyTensor( - meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)), + meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape), lazy=self._lazy, args=(self,), func=(lambda s: s[0].numpy()) ) @classmethod - def eager_to_meta(cls, t: Tensor) -> Tensor: - if t.is_meta: - return t - return t.detach().to("meta") - - @classmethod - def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor: - m = m.detach() - if not m.is_meta: - m = m.to("meta") - m.dtype = dtype - return m + def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor: + return torch.empty(size=shape, dtype=dtype, device="meta") @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -2464,8 +2447,8 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. 
@@ -2464,8 +2447,8 @@ def parse_args() -> argparse.Namespace:
         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
-        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
+        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
     )
     parser.add_argument(
         "--bigendian", action="store_true",
@@ -2523,6 +2506,7 @@ def main() -> None:
         "f32": gguf.LlamaFileType.ALL_F32,
         "f16": gguf.LlamaFileType.MOSTLY_F16,
         "bf16": gguf.LlamaFileType.MOSTLY_BF16,
+        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
         "auto": gguf.LlamaFileType.GUESSED,
     }
diff --git a/gguf-py/gguf/__init__.py b/gguf-py/gguf/__init__.py
index e5d5806c8..ea5146b16 100644
--- a/gguf-py/gguf/__init__.py
+++ b/gguf-py/gguf/__init__.py
@@ -2,5 +2,6 @@ from .constants import *
 from .lazy import *
 from .gguf_reader import *
 from .gguf_writer import *
+from .quants import *
 from .tensor_mapping import *
 from .vocab import *
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 96574358d..d5e323a52 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -13,6 +13,7 @@ from string import ascii_letters, digits
 import numpy as np
 
 from .constants import (
+    GGML_QUANT_SIZES,
     GGUF_DEFAULT_ALIGNMENT,
     GGUF_MAGIC,
     GGUF_VERSION,
@@ -195,7 +196,7 @@ class GGUFWriter:
         return ((x + n - 1) // n) * n
 
     def add_tensor_info(
-        self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32],
+        self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
         tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
         if self.state is not WriterState.EMPTY:
@@ -208,10 +209,6 @@ class GGUFWriter:
         encoded_name = name.encode("utf-8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name
-        n_dims = len(tensor_shape)
-        self.ti_data += self._pack("I", n_dims)
-        for i in range(n_dims):
-            self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
         if raw_dtype is None:
             if tensor_dtype == np.float16:
                 dtype = GGMLQuantizationType.F16
@@ -231,6 +228,15 @@ class GGUFWriter:
                 raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
         else:
             dtype = raw_dtype
+            if tensor_dtype == np.uint8:
+                block_size, type_size = GGML_QUANT_SIZES[raw_dtype]
+                if tensor_shape[-1] % type_size != 0:
+                    raise ValueError(f"Quantized tensor row size ({tensor_shape[-1]}) is not a multiple of {dtype.name} type size ({type_size})")
+                tensor_shape = tuple(tensor_shape[:-1]) + (tensor_shape[-1] // type_size * block_size,)
+        n_dims = len(tensor_shape)
+        self.ti_data += self._pack("I", n_dims)
+        for i in range(n_dims):
+            self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
         self.ti_data += self._pack("I", dtype)
         self.ti_data += self._pack("Q", self.offset_tensor)
         self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py
index 650bea11c..1167335b8 100644
--- a/gguf-py/gguf/lazy.py
+++ b/gguf-py/gguf/lazy.py
@@ -6,6 +6,7 @@ from typing import Any, Callable
 from collections import deque
 
 import numpy as np
+from numpy._typing import _Shape
 from numpy.typing import DTypeLike
 
 
@@ -110,7 +111,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
         return o
 
     @classmethod
-    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
+    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
         def wrapped_fn(*args, **kwargs):
             if kwargs is None:
                 kwargs = {}
@@ -130,9 +131,14 @@ class LazyBase(ABC, metaclass=LazyMeta):
                 res = args[0]
                 assert isinstance(res, cls)
                 res = res._meta
-                # allow operations to override the dtype
+                # allow operations to override the dtype and shape
                 if meta_noop is not True:
-                    res = cls.meta_with_dtype(res, meta_noop)
+                    if isinstance(meta_noop, tuple):
+                        dtype, shape = meta_noop
+                        assert callable(shape)
+                        res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
+                    else:
+                        res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
 
             if isinstance(res, cls._tensor_type):
                 def collect_replace(t: LazyBase):
@@ -168,7 +174,12 @@ class LazyBase(ABC, metaclass=LazyMeta):
                 while _t._data is None:
                     lt = _t._lazy.popleft()
                     if lt._data is not None:
-                        raise ValueError(f"{lt} did not belong in the lazy queue")
+                        # Lazy tensor did not belong in the lazy queue.
+                        # Weirdly only happens with Bloom models...
+                        # likely because tensors aren't unique in the queue.
+                        # The final output is still the same as in eager mode,
+                        # so it's safe to ignore this.
+                        continue
                     assert lt._func is not None
                     lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
                     lt._data = lt._func(lt._args)
@@ -183,12 +194,12 @@ class LazyBase(ABC, metaclass=LazyMeta):
 
     @classmethod
     def eager_to_meta(cls, t: Any) -> Any:
-        return cls.meta_with_dtype(t, t.dtype)
+        return cls.meta_with_dtype_and_shape(t.dtype, t.shape)
 
     # must be overridden, meta tensor init is backend-specific
     @classmethod
     @abstractmethod
-    def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
+    def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass
 
     @classmethod
     def from_eager(cls, t: Any) -> Any:
@@ -205,15 +216,15 @@ class LazyNumpyTensor(LazyBase):
     _tensor_type = np.ndarray
 
     @classmethod
-    def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
+    def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: _Shape) -> np.ndarray[Any, Any]:
         # The initial idea was to use np.nan as the fill value,
         # but non-float types like np.int16 can't use that.
         # So zero it is.
         cheat = np.zeros(1, dtype)
-        return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
+        return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape))
 
     def astype(self, dtype, *args, **kwargs):
-        meta = type(self).meta_with_dtype(self._meta, dtype)
+        meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
         # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
         return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
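
To illustrate the extended meta_noop (the pairwise_sum function below is made up for this sketch, not part of the patch): passing a (dtype, shape-mapper) tuple lets a wrapped function change both the dtype and the shape of the meta tensor, which is what the Q8_0 path in quants.py needs:

    import numpy as np
    from gguf.lazy import LazyNumpyTensor

    def pairwise_sum(a: np.ndarray) -> np.ndarray:
        # hypothetical shape-changing op: sums adjacent pairs along the last dim
        return a.reshape((*a.shape[:-1], a.shape[-1] // 2, 2)).sum(axis=-1)

    # the lambda mirrors pairwise_sum's effect on shapes, so the lazy meta
    # tensor ends up with the correct shape without evaluating anything
    lazy_pairwise_sum = LazyNumpyTensor._wrap_fn(
        pairwise_sum,
        meta_noop=(np.float32, lambda s: (*s[:-1], s[-1] // 2)),
    )
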
diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py
new file mode 100644
index 000000000..e7fc0eae3
--- /dev/null
+++ b/gguf-py/gguf/quants.py
@@ -0,0 +1,109 @@
+from __future__ import annotations
+from typing import Callable
+
+from numpy.typing import DTypeLike
+
+from .constants import GGML_QUANT_SIZES, GGMLQuantizationType
+from .lazy import LazyNumpyTensor
+
+import numpy as np
+
+
+# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
+def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
+    n = n.astype(np.float32, copy=False).view(np.int32)
+    # force nan to quiet
+    n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
+    # flush subnormals to zero
+    n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
+    # round to nearest even
+    n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
+    return n.astype(np.int16)
+
+
+# This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
+def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
+    rows = arr.reshape((-1, arr.shape[-1]))
+    osize = 1
+    for dim in oshape:
+        osize *= dim
+    out = np.empty(shape=osize, dtype=otype)
+    # compute over groups of 16 rows (arbitrary, but seems good for performance)
+    n_groups = rows.shape[0] // 16
+    np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
+    return out.reshape(oshape)
+
+
+def __quantize_bf16_array(n: np.ndarray) -> np.ndarray:
+    return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.int16, oshape=n.shape)
+
+
+__quantize_bf16_lazy = LazyNumpyTensor._wrap_fn(__quantize_bf16_array, meta_noop=np.int16)
+
+
+def quantize_bf16(n: np.ndarray):
+    if type(n) is LazyNumpyTensor:
+        return __quantize_bf16_lazy(n)
+    else:
+        return __quantize_bf16_array(n)
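
A worked example of the round-to-nearest-even step in __compute_fp32_to_bf16 (bit patterns chosen for illustration): adding 0x7fff plus the lowest kept bit folds the discarded 16 mantissa bits into the retained ones, breaking exact ties toward the even bf16 value:

    def fp32_bits_to_bf16_bits(n: int) -> int:
        return (n + (0x7fff + ((n >> 16) & 1))) >> 16

    print(hex(fp32_bits_to_bf16_bits(0x3F804000)))  # 0x3f80: below halfway, rounds down
    print(hex(fp32_bits_to_bf16_bits(0x3F808000)))  # 0x3f80: exact tie, stays on even
    print(hex(fp32_bits_to_bf16_bits(0x3F818000)))  # 0x3f82: exact tie, rounds up to even
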
+
+
+__q8_block_size, __q8_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q8_0]
+
+
+def can_quantize_to_q8_0(n: np.ndarray) -> bool:
+    return n.shape[-1] % __q8_block_size == 0
+
+
+# round away from zero
+# ref: https://stackoverflow.com/a/59143326/22827863
+def np_roundf(n: np.ndarray) -> np.ndarray:
+    a = abs(n)
+    floored = np.floor(a)
+    b = floored + np.floor(2 * (a - floored))
+    return np.sign(n) * b
+
+
+def __quantize_q8_0_shape_change(s: tuple[int, ...]) -> tuple[int, ...]:
+    return (*s[:-1], s[-1] // __q8_block_size * __q8_type_size)
+
+
+# Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
+def __quantize_q8_0_rows(n: np.ndarray) -> np.ndarray:
+    shape = n.shape
+    assert shape[-1] % __q8_block_size == 0
+
+    n_blocks = n.size // __q8_block_size
+
+    blocks = n.reshape((n_blocks, __q8_block_size)).astype(np.float32, copy=False)
+
+    d = abs(blocks).max(axis=1, keepdims=True) / 127
+    with np.errstate(divide="ignore"):
+        id = np.where(d == 0, 0, 1 / d)
+    qs = np_roundf(blocks * id)
+
+    # (n_blocks, 2)
+    d = d.astype(np.float16).view(np.uint8)
+    # (n_blocks, block_size)
+    qs = qs.astype(np.int8).view(np.uint8)
+
+    assert d.shape[1] + qs.shape[1] == __q8_type_size
+
+    return np.concatenate([d, qs], axis=1).reshape(__quantize_q8_0_shape_change(shape))
+
+
+def __quantize_q8_0_array(n: np.ndarray) -> np.ndarray:
+    return __apply_over_grouped_rows(__quantize_q8_0_rows, arr=n, otype=np.uint8, oshape=__quantize_q8_0_shape_change(n.shape))
+
+
+__quantize_q8_0_lazy = LazyNumpyTensor._wrap_fn(
+    __quantize_q8_0_array,
+    meta_noop=(np.uint8, __quantize_q8_0_shape_change),
+)
+
+
+def quantize_q8_0(data: np.ndarray):
+    if type(data) is LazyNumpyTensor:
+        return __quantize_q8_0_lazy(data)
+    else:
+        return __quantize_q8_0_array(data)
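
Finally, a usage sketch of the new helpers (assumes the gguf-py package from this branch is importable; shapes and values are arbitrary):

    import numpy as np
    import gguf

    rng = np.random.default_rng(42)
    x = rng.standard_normal((16, 64), dtype=np.float32)  # row size 64 is a multiple of 32

    assert gguf.can_quantize_to_q8_0(x)
    q = gguf.quantize_q8_0(x)
    print(q.dtype, q.shape)  # uint8 (16, 68): two 34-byte blocks per row

    # decode the first block by hand to check the layout: [d: fp16][qs: 32 x int8]
    blk = q.reshape(-1, 34)[0]
    d = blk[:2].view(np.float16)[0].astype(np.float32)
    qs = blk[2:].view(np.int8).astype(np.float32)
    print(np.abs(d * qs - x.reshape(-1, 32)[0]).max())  # small quantization error

The same calls work on a LazyNumpyTensor: quantize_q8_0 dispatches to the lazy wrapper, and nothing is computed until the data is actually needed.
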