#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import annotations from dataclasses import dataclass import logging import argparse import os import sys from pathlib import Path from types import EllipsisType from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast import torch if TYPE_CHECKING: from torch import Tensor if 'NO_LOCAL_GGUF' not in os.environ: sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) import gguf # reuse model definitions from convert_hf_to_gguf.py from convert_hf_to_gguf import Model logger = logging.getLogger("lora-to-gguf") @dataclass class PartialLoraTensor: A: Tensor | None = None B: Tensor | None = None # magic to support tensor shape modifications and splitting class LoraTorchTensor: _lora_A: Tensor _lora_B: Tensor _rank: int def __init__(self, A: Tensor, B: Tensor): assert len(A.shape) == len(B.shape) if A.dtype != B.dtype: A = A.to(torch.float32) B = B.to(torch.float32) self._lora_A = A self._lora_B = B assert self._lora_A.shape[-2] == self._lora_B.shape[-1] self._rank = self._lora_B.shape[-1] def __getitem__( self, indices: ( SupportsIndex | slice | tuple[SupportsIndex | slice | EllipsisType | Tensor, ...] ), ) -> LoraTorchTensor: shape = self.shape if isinstance(indices, (SupportsIndex, slice)): if len(shape) > 2: return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) else: raise NotImplementedError elif isinstance(indices, tuple): assert len(indices) > 0 if isinstance(indices[-1], EllipsisType): return self[indices[:-1]] # expand ellipsis indices = tuple( u for v in ( ( (slice(None, None) for _ in range(len(indices) - 1)) if isinstance(i, EllipsisType) else (i,) ) for i in indices ) for u in v ) if len(indices) < len(shape): indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape)))) # TODO: make sure this is correct # lora_A has a shape which looks like (..., 1, 1, rank, self.shape[-1]) indices_A = ( *( 0 if isinstance(i, SupportsIndex) else slice(None, None) for i in indices[:-2] ), slice(None, None), indices[-1], ) indices_B = indices[:-1] return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B]) else: raise NotImplementedError @property def dtype(self) -> torch.dtype: assert self._lora_A.dtype == self._lora_B.dtype return self._lora_A.dtype @property def shape(self) -> tuple[int, ...]: return (*self._lora_B.shape[:-1], self._lora_A.shape[-1]) def size(self, dim=None): assert dim is None return self.shape def reshape(self, *shape: int | tuple[int]) -> LoraTorchTensor: if isinstance(shape[0], tuple): new_shape: tuple[int] = shape[0] else: new_shape = cast(tuple[int], shape) orig_shape = self.shape if new_shape[-1] != orig_shape[-1]: raise NotImplementedError return LoraTorchTensor( self._lora_A.reshape((*(1 for _ in new_shape[:-2]), *self._lora_A.shape[-2:])), self._lora_B.reshape((*new_shape[:-1], self._rank)), ) def reshape_as(self, other: Tensor) -> LoraTorchTensor: return self.reshape(*other.shape) def view(self, *size: int) -> LoraTorchTensor: return self.reshape(*size) def permute(self, *dims: int) -> LoraTorchTensor: shape = self.shape dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims) if dims[-1] == -2 and dims[-2] == -1: return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims)) else: assert dims[-1] == -1 assert all(dim == 1 for dim in self._lora_A.shape[:-2]) return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims)) def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor: shape = self.shape dims = [i for i in range(len(shape))] dims[dim0], dims[dim1] = dims[dim1], dims[dim0] return self.permute(*dims) def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor: return self.transpose(axis0, axis1) def to(self, *args, **kwargs): return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)) @classmethod def __torch_function__(cls, func: Callable, types, args=(), kwargs=None): del types # unused if kwargs is None: kwargs = {} if func is torch.permute: return type(args[0]).permute(*args, **kwargs) elif func is torch.reshape: return type(args[0]).reshape(*args, **kwargs) elif func is torch.stack: assert isinstance(args[0], Sequence) dim = kwargs.get("dim", 0) assert dim == 0 return LoraTorchTensor( torch.stack([a._lora_A for a in args[0]], dim), torch.stack([b._lora_B for b in args[0]], dim), ) elif func is torch.cat: assert isinstance(args[0], Sequence) dim = kwargs.get("dim", 0) assert dim == 0 if len(args[0][0].shape) > 2: return LoraTorchTensor( torch.cat([a._lora_A for a in args[0]], dim), torch.cat([b._lora_B for b in args[0]], dim), ) else: return LoraTorchTensor( args[0][0]._lora_A, # TODO: is this correct? (can't cat over the rank) torch.cat([b._lora_B for b in args[0]], dim), ) else: raise NotImplementedError def get_base_tensor_name(lora_tensor_name: str) -> str: base_name = lora_tensor_name.replace("base_model.model.", "") base_name = base_name.replace(".lora_A.weight", ".weight") base_name = base_name.replace(".lora_B.weight", ".weight") return base_name def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file") parser.add_argument( "--outfile", type=Path, help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0", ) parser.add_argument( "--bigendian", action="store_true", help="model is executed on big endian machine", ) parser.add_argument( "--verbose", action="store_true", help="increase output verbosity", ) parser.add_argument( "--base", type=Path, required=True, help="directory containing base model file", ) parser.add_argument( "lora_path", type=Path, help="directory containing LoRA adapter file", ) return parser.parse_args() if __name__ == '__main__': args = parse_args() logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) ftype_map: dict[str, gguf.LlamaFileType] = { "f32": gguf.LlamaFileType.ALL_F32, "f16": gguf.LlamaFileType.MOSTLY_F16, "bf16": gguf.LlamaFileType.MOSTLY_BF16, "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, } ftype = ftype_map[args.outtype] dir_base_model = args.base dir_lora = args.lora_path input_json = os.path.join(dir_lora, "adapter_config.json") input_model = os.path.join(dir_lora, "adapter_model.safetensors") if args.outfile is not None: fname_out = args.outfile else: # output in the same directory as the model by default fname_out = dir_lora / 'ggml-lora-{ftype}.gguf' if os.path.exists(input_model): # lazy import load_file only if lora is in safetensors format. from safetensors.torch import load_file lora_model = load_file(input_model, device="cpu") else: input_model = os.path.join(dir_lora, "adapter_model.bin") lora_model = torch.load(input_model, map_location="cpu", weights_only=True) # load base model logger.info(f"Loading base model: {dir_base_model.name}") hparams = Model.load_hparams(dir_base_model) with torch.inference_mode(): try: model_class = Model.from_model_architecture(hparams["architectures"][0]) except NotImplementedError: logger.error(f"Model {hparams['architectures'][0]} is not supported") sys.exit(1) class LoraModel(model_class): model_arch = model_class.model_arch def get_tensors(self) -> Iterator[tuple[str, Tensor]]: tensor_map: dict[str, PartialLoraTensor] = {} for name, tensor in lora_model.items(): base_name = get_base_tensor_name(name) is_lora_a = ".lora_A.weight" in name is_lora_b = ".lora_B.weight" in name if not is_lora_a and not is_lora_b: if ".base_layer.weight" in name: continue logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") sys.exit(1) if base_name in tensor_map: if is_lora_a: tensor_map[base_name].A = tensor else: tensor_map[base_name].B = tensor else: if is_lora_a: tensor_map[base_name] = PartialLoraTensor(A=tensor) else: tensor_map[base_name] = PartialLoraTensor(B=tensor) for name, tensor in tensor_map.items(): assert tensor.A is not None assert tensor.B is not None yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B))) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: dest = super().modify_tensors(data_torch, name, bid) for dest_name, dest_data in dest: assert isinstance(dest_data, LoraTorchTensor) # logger.info(f"{orig_name} --> {dest_name}") yield (dest_name + ".lora_a", dest_data._lora_A) yield (dest_name + ".lora_b", dest_data._lora_B) model_instance = LoraModel(dir_base_model, ftype, fname_out, args.bigendian, False, False, None) logger.info("Set model parameters") model_instance.set_gguf_parameters() # adapter_config = json.load(input_json) model_instance.gguf_writer.add_string("training.type", "finetune_lora") model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) logger.info("Exporting model...") model_instance.write() logger.info(f"Model successfully exported to {model_instance.fname_out}")