2024-07-08 16:35:27 +02:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
2024-07-10 00:26:38 +02:00
from dataclasses import dataclass
2024-07-08 16:35:27 +02:00
import logging
import argparse
import os
import sys
2024-07-15 08:35:06 +02:00
import json
from math import prod
2024-07-08 16:35:27 +02:00
from pathlib import Path
2024-07-15 08:35:06 +02:00
from typing import TYPE_CHECKING , Any , Callable , Iterable , Iterator , Sequence , SupportsIndex , cast
2024-07-08 16:35:27 +02:00
import torch
if TYPE_CHECKING :
from torch import Tensor
if ' NO_LOCAL_GGUF ' not in os . environ :
sys . path . insert ( 1 , str ( Path ( __file__ ) . parent / ' gguf-py ' ) )
import gguf
# reuse model definitions from convert_hf_to_gguf.py
2024-07-15 08:35:06 +02:00
from convert_hf_to_gguf import LazyTorchTensor , Model
2024-07-08 16:35:27 +02:00
logger = logging . getLogger ( " lora-to-gguf " )
2024-07-08 21:55:41 +02:00
2024-07-10 00:26:38 +02:00
@dataclass()
class PartialLoraTensor:
    """Collects the two halves of one LoRA adapter weight while scanning the adapter file.

    ``lora_A`` and ``lora_B`` tensors arrive as separate entries, so either
    half may still be ``None`` until its partner has been seen.
    """

    # tensor whose name contained ".lora_A.weight", once encountered
    A: Tensor | None = None
    # tensor whose name contained ".lora_B.weight", once encountered
    B: Tensor | None = None
# magic to support tensor shape modifications and splitting
class LoraTorchTensor:
    """Stand-in for the dense LoRA delta ``B @ A`` that keeps it factored.

    The represented tensor has shape ``(*B.shape[:-1], A.shape[-1])``. Shape
    manipulations (indexing, ``reshape``/``view``, ``permute``/``transpose``,
    ``torch.stack``/``torch.cat``) are applied to the two low-rank factors
    instead of their product. This lets an architecture's ``modify_tensors``
    be reused on adapter weights. Operations that cannot be expressed on the
    factors raise ``NotImplementedError``.
    """

    _lora_A: Tensor  # (n_rank, row_size)
    _lora_B: Tensor  # (col_size, n_rank)
    _rank: int

    def __init__(self, A: Tensor, B: Tensor):
        """Wrap factors ``A`` (``(..., rank, row_size)``) and ``B`` (``(..., col_size, rank)``).

        Both must have the same number of dimensions and agree on the rank.
        If their dtypes differ, both are upcast to float32.
        """
        assert len(A.shape) == len(B.shape)
        assert A.shape[-2] == B.shape[-1]
        if A.dtype != B.dtype:
            A = A.to(torch.float32)
            B = B.to(torch.float32)
        self._lora_A = A
        self._lora_B = B
        self._rank = B.shape[-1]

    def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
        """Return the ``(A, B)`` factor pair."""
        return (self._lora_A, self._lora_B)

    def __getitem__(
        self,
        indices: (
            SupportsIndex
            | slice
            | tuple[SupportsIndex | slice | Tensor, ...]  # TODO: add ellipsis in the type signature
        ),
    ) -> LoraTorchTensor:
        """Index the represented tensor by indexing its factors.

        A lone int or slice selects along the first dimension. For a 2D tensor
        only slices are supported; they select rows of ``B``.

        A tuple may contain one ``Ellipsis``, which expands to as many full
        slices as needed; missing trailing indices default to full slices.
        In the tuple form:

        * the last index selects columns of ``A``;
        * the remaining indices select along ``B``'s leading/row dimensions;
        * those indices are mirrored onto ``A``'s leading dimensions. Integers
          are taken modulo ``A``'s size there, and other indices are skipped
          where ``A`` has size 1 (broadcast).

        Raises:
            NotImplementedError: for an int index into a 2D tensor (the result
                would be a vector) or an unsupported index type.
        """
        if indices is Ellipsis:
            # ``t[...]`` passes a bare Ellipsis rather than a 1-tuple
            indices = (Ellipsis,)
        shape = self.shape
        if isinstance(indices, SupportsIndex):
            if len(shape) > 2:
                # NOTE(review): A[indices] fails for a nonzero index when A's first
                # dim is broadcast (size 1); the tuple form ``t[i, ...]`` handles that
                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
            else:
                raise NotImplementedError  # can't return a vector
        elif isinstance(indices, slice):
            if len(shape) > 2:
                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
            else:
                return LoraTorchTensor(self._lora_A, self._lora_B[indices])
        elif isinstance(indices, tuple):
            assert len(indices) > 0
            # Expand the Ellipsis into exactly the number of full slices needed so
            # the explicit indices after it line up with the trailing dimensions.
            n_explicit = sum(1 for i in indices if i is not Ellipsis)
            n_fill = max(len(shape) - n_explicit, 0)
            indices = tuple(
                u
                for i in indices
                for u in ((slice(None, None),) * n_fill if i is Ellipsis else (i,))
            )

            # any dimensions still not indexed are taken whole
            if len(indices) < len(shape):
                indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))

            # A's leading dims are either shared with B or broadcast (size 1).
            # Integers are taken modulo A's size (0 on a broadcast dim). Other
            # indices apply to shared dims and are skipped on broadcast dims.
            # A's rank dim is kept whole; its row dim receives the last index.
            indices_A = (
                *(
                    (
                        j.__index__() % self._lora_A.shape[i]
                        if isinstance(j, SupportsIndex)
                        else (slice(None, None) if self._lora_A.shape[i] == 1 else j)
                    )
                    for i, j in enumerate(indices[:-2])
                ),
                slice(None, None),
                indices[-1],
            )
            # B's trailing rank dim is kept whole
            indices_B = indices[:-1]
            return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
        else:
            raise NotImplementedError  # unknown indice type

    @property
    def dtype(self) -> torch.dtype:
        """Common dtype of both factors."""
        assert self._lora_A.dtype == self._lora_B.dtype
        return self._lora_A.dtype

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the represented product: ``B``'s leading/column dims plus ``A``'s row dim."""
        assert len(self._lora_A.shape) == len(self._lora_B.shape)
        return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])

    def size(self, dim=None):
        """Return the full shape, or the extent of dimension ``dim`` when given (like ``Tensor.size``)."""
        if dim is None:
            return self.shape
        return self.shape[dim]

    def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
        """Reshape while keeping the last (row) dimension unchanged.

        Sizes may be given as varargs or as a single tuple/list; ``-1``
        entries are inferred from the element count. ``B`` takes the new
        leading/column dims; ``A`` gets size-1 leading dims so it broadcasts.

        Raises:
            NotImplementedError: if the result would have fewer than two
                dimensions or a different last dimension.
        """
        if isinstance(shape[0], (tuple, list)):
            # called as reshape((a, b, ...)) / reshape([a, b, ...]), e.g. via torch.reshape
            new_shape: tuple[int, ...] = tuple(shape[0])
        else:
            new_shape = cast(tuple[int, ...], shape)
        orig_shape = self.shape
        if len(new_shape) < 2:
            raise NotImplementedError  # can't become a vector

        # expand -1 in the shape
        if any(dim == -1 for dim in new_shape):
            n_elems = prod(orig_shape)
            n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
            assert n_elems % n_new_elems == 0
            new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)

        if new_shape[-1] != orig_shape[-1]:
            raise NotImplementedError  # can't reshape the row size trivially

        # NOTE(review): reshaping A to all-1 leading dims assumes A has no
        # non-broadcast leading dims (e.g. after torch.stack); otherwise
        # A.reshape(shape_A) fails on the element count.
        shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
        shape_B = (*new_shape[:-1], self._rank)
        return LoraTorchTensor(
            self._lora_A.reshape(shape_A),
            self._lora_B.reshape(shape_B),
        )

    def reshape_as(self, other: Tensor) -> LoraTorchTensor:
        """Reshape to ``other.shape``."""
        return self.reshape(*other.shape)

    def view(self, *size: int) -> LoraTorchTensor:
        """Alias for :meth:`reshape` (the factors are reshaped, not viewed)."""
        return self.reshape(*size)

    def permute(self, *dims: int) -> LoraTorchTensor:
        """Permute dimensions of the represented tensor.

        Accepts ``permute(1, 0)`` as well as ``permute((1, 0))``, like
        ``Tensor.permute``; the tuple form is what ``torch.permute`` passes.
        Supported cases are:

        * permutations that keep the last dimension in place, which only
          touch ``B``;
        * a plain 2D transpose, which swaps and transposes both factors.

        Raises:
            NotImplementedError: for any other permutation.
        """
        if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
            dims = tuple(dims[0])
        shape = self.shape
        # normalize to negative indices so they line up with the trailing dims
        dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
        if dims[-1] == -1:
            # TODO: support higher dimensional A shapes bigger than 1
            assert all(dim == 1 for dim in self._lora_A.shape[:-2])
            return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
        if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
            # (B @ A).T == A.T @ B.T: the factors swap roles
            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
        else:
            # TODO: compose the above two
            raise NotImplementedError

    def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
        """Swap two dimensions; see :meth:`permute` for the supported cases."""
        shape = self.shape
        dims = list(range(len(shape)))
        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
        return self.permute(*dims)

    def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
        """Alias for :meth:`transpose`."""
        return self.transpose(axis0, axis1)

    def to(self, *args, **kwargs):
        """Apply ``Tensor.to(*args, **kwargs)`` to both factors."""
        return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))

    @classmethod
    def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
        """Handle the torch functions that conversion code applies to these tensors.

        Supports ``torch.permute``, ``torch.reshape``, and ``torch.stack`` /
        ``torch.cat`` along dim 0. Anything else raises ``NotImplementedError``.
        """
        del types  # unused

        if kwargs is None:
            kwargs = {}

        if func is torch.permute:
            return type(args[0]).permute(*args, **kwargs)
        elif func is torch.reshape:
            return type(args[0]).reshape(*args, **kwargs)
        elif func is torch.stack:
            assert isinstance(args[0], Sequence)
            # dim may be positional (torch.stack(ts, 1)) or a keyword; reading only
            # the keyword would silently stack along dim 0
            dim = kwargs.get("dim", args[1] if len(args) > 1 else 0)
            assert dim == 0
            return LoraTorchTensor(
                torch.stack([a._lora_A for a in args[0]], dim),
                torch.stack([b._lora_B for b in args[0]], dim),
            )
        elif func is torch.cat:
            assert isinstance(args[0], Sequence)
            # same positional-or-keyword handling of dim as for torch.stack
            dim = kwargs.get("dim", args[1] if len(args) > 1 else 0)
            assert dim == 0
            if len(args[0][0].shape) > 2:
                return LoraTorchTensor(
                    torch.cat([a._lora_A for a in args[0]], dim),
                    torch.cat([b._lora_B for b in args[0]], dim),
                )
            elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
                # 2D: stacking rows only touches B when every part shares the same A
                return LoraTorchTensor(
                    args[0][0]._lora_A,
                    torch.cat([b._lora_B for b in args[0]], dim),
                )
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError
2024-07-10 00:23:07 +02:00
def get_base_tensor_name(lora_tensor_name: str) -> str:
    """Map a PEFT LoRA tensor name to the name of the base weight it adapts.

    Every occurrence of ``base_model.model.`` is removed. Every occurrence of
    ``.lora_A.weight`` or ``.lora_B.weight`` becomes ``.weight``. The rewrites
    are applied in that order.
    """
    rewrites = (
        ("base_model.model.", ""),
        (".lora_A.weight", ".weight"),
        (".lora_B.weight", ".weight"),
    )
    result = lora_tensor_name
    for needle, replacement in rewrites:
        result = result.replace(needle, replacement)
    return result
2024-07-08 16:35:27 +02:00
def parse_args() -> argparse.Namespace:
    """Parse the converter's command line from ``sys.argv``.

    Returns:
        Namespace with ``outfile``, ``outtype``, ``bigendian``, ``no_lazy``,
        ``verbose``, ``base`` and the positional ``lora_path``.
    """
    # (flag, add_argument keyword options), registered in this order
    argument_specs = [
        ("--outfile", dict(
            type=Path,
            help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
        )),
        ("--outtype", dict(
            type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
            help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
        )),
        ("--bigendian", dict(
            action="store_true",
            help="model is executed on big endian machine",
        )),
        ("--no-lazy", dict(
            action="store_true",
            help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
        )),
        ("--verbose", dict(
            action="store_true",
            help="increase output verbosity",
        )),
        ("--base", dict(
            type=Path, required=True,
            help="directory containing base model file",
        )),
        ("lora_path", dict(
            type=Path,
            help="directory containing LoRA adapter file",
        )),
    ]
    cli = argparse.ArgumentParser(
        description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file",
    )
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
if __name__ == '__main__':
    # Entry point: load the adapter tensors, wrap each lora_A/lora_B pair as a
    # LoraTorchTensor, run them through the base architecture's tensor
    # handling, and write the factors plus LoRA metadata to a GGUF file.
    args = parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    ftype_map: dict[str, gguf.LlamaFileType] = {
        "f32": gguf.LlamaFileType.ALL_F32,
        "f16": gguf.LlamaFileType.MOSTLY_F16,
        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
        "auto": gguf.LlamaFileType.GUESSED,
    }

    ftype = ftype_map[args.outtype]

    dir_base_model: Path = args.base
    dir_lora: Path = args.lora_path
    lora_config = dir_lora / "adapter_config.json"
    input_model = dir_lora / "adapter_model.safetensors"

    if args.outfile is not None:
        fname_out = args.outfile
    else:
        # output in the same directory as the model by default
        fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'

    if input_model.exists():
        # lazy import load_file only if lora is in safetensors format.
        from safetensors.torch import load_file

        lora_model = load_file(input_model, device="cpu")
    else:
        # fall back to the PyTorch checkpoint; weights_only=True asks torch.load
        # not to unpickle arbitrary objects
        input_model = dir_lora / "adapter_model.bin"
        lora_model = torch.load(input_model, map_location="cpu", weights_only=True)

    # load base model
    logger.info(f"Loading base model: {dir_base_model.name}")
    hparams = Model.load_hparams(dir_base_model)
    with torch.inference_mode():
        try:
            model_class = Model.from_model_architecture(hparams["architectures"][0])
        except NotImplementedError:
            logger.error(f"Model {hparams['architectures'][0]} is not supported")
            sys.exit(1)

        class LoraModel(model_class):
            # Reuses the base architecture's conversion, but yields the adapter's
            # factor pairs (as LoraTorchTensor) instead of the base weights.
            model_arch = model_class.model_arch

            def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                # pair up lora_A / lora_B tensors by the base weight name they adapt
                tensor_map: dict[str, PartialLoraTensor] = {}

                for name, tensor in lora_model.items():
                    if self.lazy:
                        tensor = LazyTorchTensor.from_eager(tensor)
                    base_name = get_base_tensor_name(name)
                    is_lora_a = ".lora_A.weight" in name
                    is_lora_b = ".lora_B.weight" in name
                    if not is_lora_a and not is_lora_b:
                        if ".base_layer.weight" in name:
                            continue
                        logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
                        sys.exit(1)

                    if base_name in tensor_map:
                        if is_lora_a:
                            tensor_map[base_name].A = tensor
                        else:
                            tensor_map[base_name].B = tensor
                    else:
                        if is_lora_a:
                            tensor_map[base_name] = PartialLoraTensor(A=tensor)
                        else:
                            tensor_map[base_name] = PartialLoraTensor(B=tensor)

                for name, tensor in tensor_map.items():
                    # An unpaired factor cannot form a delta. Report it explicitly
                    # rather than via assert, which is stripped under `python -O`.
                    if tensor.A is None or tensor.B is None:
                        missing = "lora_A" if tensor.A is None else "lora_B"
                        logger.error(f"Tensor '{name}' has no matching {missing} tensor")
                        sys.exit(1)
                    yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))

            def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
                # let the base architecture rename/reshape the factored tensor, then
                # emit each resulting factor as its own output tensor
                dest = super().modify_tensors(data_torch, name, bid)
                for dest_name, dest_data in dest:
                    assert isinstance(dest_data, LoraTorchTensor)
                    lora_a, lora_b = dest_data.get_lora_A_B()

                    yield (dest_name + ".lora_a", lora_a)
                    yield (dest_name + ".lora_b", lora_b)

        model_instance = LoraModel(
            dir_base_model,
            ftype,
            fname_out,
            is_big_endian=args.bigendian,
            use_temp_file=False,
            eager=args.no_lazy,
            model_name=None,
        )
        logger.info("Set model parameters")
        model_instance.set_gguf_parameters()

        # explicit encoding: JSON is UTF-8, and the platform default may not be
        with open(lora_config, "r", encoding="utf-8") as f:
            lparams: dict[str, Any] = json.load(f)

        alpha = lparams["lora_alpha"]

        model_instance.gguf_writer.add_string("training.type", "finetune_lora")
        model_instance.gguf_writer.add_float32("training.lora.alpha", float(alpha))

        model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
        logger.info("Exporting model...")
        model_instance.write()
        logger.info(f"Model successfully exported to {model_instance.fname_out}")