mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
conversion: only allow selected models
This commit is contained in:
parent
03d24cae19
commit
ee2b35c65f
@ -373,6 +373,9 @@ class Model:
|
||||
except KeyError:
|
||||
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
return False
|
||||
|
||||
# used for GPT-2 BPE and WordPiece vocabs
|
||||
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
|
||||
tokens: list[str] = []
|
||||
@ -1416,9 +1419,9 @@ class LlamaModel(Model):
|
||||
n_head = self.hparams["num_attention_heads"]
|
||||
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||
|
||||
if name.endswith(("q_proj.weight", "q_proj.bias")):
|
||||
if name.endswith(("q_proj.weight", "q_proj.bias", "q_proj.lora_B.weight")):
|
||||
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
|
||||
if name.endswith(("k_proj.weight", "k_proj.bias")):
|
||||
if name.endswith(("k_proj.weight", "k_proj.bias", "k_proj.lora_B.weight")):
|
||||
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
|
||||
|
||||
# process the experts separately
|
||||
@ -1466,6 +1469,10 @@ class LlamaModel(Model):
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
# TODO: support lora conversion for MOE
|
||||
return "num_local_experts" not in self.hparams
|
||||
|
||||
|
||||
@Model.register("BitnetForCausalLM")
|
||||
class BitnetModel(Model):
|
||||
|
@ -9,7 +9,7 @@ import os
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Iterable, Iterator
|
||||
from typing import TYPE_CHECKING, Iterator
|
||||
|
||||
import torch
|
||||
|
||||
@ -26,6 +26,13 @@ from convert_hf_to_gguf import Model
|
||||
logger = logging.getLogger("lora-to-gguf")
|
||||
|
||||
|
||||
def get_base_tensor_name(lora_tensor_name: str) -> str:
|
||||
base_name = lora_tensor_name.replace("base_model.model.", "")
|
||||
base_name = base_name.replace(".lora_A.weight", ".weight")
|
||||
base_name = base_name.replace(".lora_B.weight", ".weight")
|
||||
return base_name
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
|
||||
@ -103,43 +110,47 @@ if __name__ == '__main__':
|
||||
|
||||
# adapter_config = json.load(input_json)
|
||||
model_instance.gguf_writer.add_string("training.type", "finetune_lora")
|
||||
if not model_instance.support_lora():
|
||||
logger.error("LoRA conversion is not yet supported for this model")
|
||||
sys.exit(1)
|
||||
|
||||
map_tensors: dict[str, Tensor] = {}
|
||||
# map original name to gguf name
|
||||
map_name: dict[str, str] = {}
|
||||
for tensor_name, tensor in lora_model.items():
|
||||
orig_name = tensor_name.replace("base_model.model.", "")
|
||||
orig_name = orig_name.replace(".lora_A.weight", ".weight")
|
||||
orig_name = orig_name.replace(".lora_B.weight", ".weight")
|
||||
base_name = get_base_tensor_name(tensor_name)
|
||||
is_lora_a = ".lora_A.weight" in tensor_name
|
||||
is_lora_b = ".lora_B.weight" in tensor_name
|
||||
if not is_lora_a and not is_lora_b:
|
||||
logger.error(f"Unexpected name '{tensor_name}': Not a lora_A or lora_B tensor")
|
||||
sys.exit(1)
|
||||
dest_name = model_instance.map_tensor_name(orig_name)
|
||||
dest_name = model_instance.map_tensor_name(base_name)
|
||||
dest_name = f"{dest_name}.lora_a" if is_lora_a else f"{dest_name}.lora_b"
|
||||
# logger.info(f"{orig_name} --> {dest_name}")
|
||||
map_tensors[dest_name] = tensor
|
||||
map_name[tensor_name] = dest_name
|
||||
|
||||
# overwrite method
|
||||
def map_tensor_name(self, name: str) -> Iterator[tuple[str, Tensor]]:
|
||||
return map_name[name]
|
||||
|
||||
# overwrite method
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
for name, tensor in map_tensors.items():
|
||||
for name, tensor in lora_model.items():
|
||||
yield (name, tensor)
|
||||
|
||||
# overwrite method
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
# TODO: This will not take into account tensor transformations
|
||||
return [(name, data_torch)]
|
||||
|
||||
# overwrite method
|
||||
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
|
||||
del name, new_name, bid, n_dims # unused
|
||||
return ftype != gguf.LlamaFileType.ALL_F32
|
||||
|
||||
model_instance._map_tensor_name = model_instance.map_tensor_name
|
||||
model_instance.map_tensor_name = types.MethodType(map_tensor_name, model_instance)
|
||||
|
||||
model_instance._get_tensors = model_instance.get_tensors
|
||||
model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
|
||||
model_instance.modify_tensors = types.MethodType(modify_tensors, model_instance)
|
||||
|
||||
model_instance._extra_f16_tensors = model_instance.extra_f16_tensors
|
||||
model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)
|
||||
|
||||
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
|
||||
logger.info("Exporting model...")
|
||||
model_instance.write()
|
||||
logger.info(f"Model successfully exported to {fname_out}")
|
||||
logger.info(f"Model successfully exported to {model_instance.fname_out}")
|
||||
|
Loading…
Reference in New Issue
Block a user