convert-lora : make --base optional (#10110)

* convert-lora : make `--base` optional

* lint

* handle case where base_model_name_or_path is invalid

* do not include metadata from base model

* clarify unspecified --base

* add small comment [no ci]

* trigger ci
This commit is contained in:
Xuan Son Nguyen 2024-11-02 12:53:17 +01:00 committed by GitHub
parent a6744e43e8
commit 7554aa4655
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 23 deletions

View File

@ -72,7 +72,8 @@ class Model:
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False, def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False, use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None, metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False): split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
if type(self) is Model: if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@ -87,7 +88,7 @@ class Model:
self.is_safetensors = len(self.part_names) > 0 self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors: if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.hparams = Model.load_hparams(self.dir_model) self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self.tensor_names = None self.tensor_names = None
@ -1541,6 +1542,17 @@ class LlamaModel(Model):
special_vocab._set_special_token("eot", 32010) special_vocab._set_special_token("eot", 32010)
special_vocab.add_to_gguf(self.gguf_writer) special_vocab.add_to_gguf(self.gguf_writer)
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
if "add_prefix_space" in tokenizer_config_json:
self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
# Apply to granite small models only
if self.hparams.get("vocab_size", 32000) == 49152:
self.gguf_writer.add_add_bos_token(False)
def set_gguf_parameters(self): def set_gguf_parameters(self):
super().set_gguf_parameters() super().set_gguf_parameters()
hparams = self.hparams hparams = self.hparams
@ -1557,17 +1569,6 @@ class LlamaModel(Model):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
if "add_prefix_space" in tokenizer_config_json:
self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
# Apply to granite small models only
if self.hparams.get("vocab_size", 32000) == 49152:
self.gguf_writer.add_add_bos_token(False)
@staticmethod @staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None): def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
if n_head_kv is not None and n_head != n_head_kv: if n_head_kv is not None and n_head != n_head_kv:

View File

@ -12,6 +12,7 @@ import json
from math import prod from math import prod
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
from transformers import AutoConfig
import torch import torch
@ -256,8 +257,8 @@ def parse_args() -> argparse.Namespace:
help="only print out what will be done, without writing any new files", help="only print out what will be done, without writing any new files",
) )
parser.add_argument( parser.add_argument(
"--base", type=Path, required=True, "--base", type=Path,
help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required", help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
) )
parser.add_argument( parser.add_argument(
"lora_path", type=Path, "lora_path", type=Path,
@ -267,6 +268,12 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args() return parser.parse_args()
def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
# normally, adapter does not come with base model config, we need to load it from AutoConfig
config = AutoConfig.from_pretrained(hf_model_id)
return config.to_dict()
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
@ -281,7 +288,7 @@ if __name__ == '__main__':
ftype = ftype_map[args.outtype] ftype = ftype_map[args.outtype]
dir_base_model: Path = args.base dir_base_model: Path | None = args.base
dir_lora: Path = args.lora_path dir_lora: Path = args.lora_path
lora_config = dir_lora / "adapter_config.json" lora_config = dir_lora / "adapter_config.json"
input_model = dir_lora / "adapter_model.safetensors" input_model = dir_lora / "adapter_model.safetensors"
@ -301,9 +308,29 @@ if __name__ == '__main__':
input_model = os.path.join(dir_lora, "adapter_model.bin") input_model = os.path.join(dir_lora, "adapter_model.bin")
lora_model = torch.load(input_model, map_location="cpu", weights_only=True) lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
# load LoRA config
with open(lora_config, "r") as f:
lparams: dict[str, Any] = json.load(f)
# load base model # load base model
logger.info(f"Loading base model: {dir_base_model.name}") if dir_base_model is None:
hparams = Model.load_hparams(dir_base_model) if "base_model_name_or_path" in lparams:
model_id = lparams["base_model_name_or_path"]
logger.info(f"Loading base model from Hugging Face: {model_id}")
try:
hparams = load_hparams_from_hf(model_id)
except OSError as e:
logger.error(f"Failed to load base model config: {e}")
logger.error("Please try downloading the base model and add its path to --base")
sys.exit(1)
else:
logger.error("'base_model_name_or_path' is not found in adapter_config.json")
logger.error("Base model config is required. Please download the base model and add its path to --base")
sys.exit(1)
else:
logger.info(f"Loading base model: {dir_base_model.name}")
hparams = Model.load_hparams(dir_base_model)
with torch.inference_mode(): with torch.inference_mode():
try: try:
model_class = Model.from_model_architecture(hparams["architectures"][0]) model_class = Model.from_model_architecture(hparams["architectures"][0])
@ -323,13 +350,15 @@ if __name__ == '__main__':
self.dir_model_card = dir_lora_model self.dir_model_card = dir_lora_model
self.lora_alpha = float(lora_alpha) self.lora_alpha = float(lora_alpha)
def set_vocab(self):
pass
def set_type(self): def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.ADAPTER) self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
def set_gguf_parameters(self): def set_gguf_parameters(self):
self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
super().set_gguf_parameters()
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# Never add extra tensors (e.g. rope_freqs) for LoRA adapters # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
@ -350,7 +379,7 @@ if __name__ == '__main__':
logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
if ".embed_tokens.weight" in name or ".lm_head.weight" in name: if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning") logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()") logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948")
sys.exit(1) sys.exit(1)
if base_name in tensor_map: if base_name in tensor_map:
@ -384,9 +413,6 @@ if __name__ == '__main__':
yield (dest_name + ".lora_a", lora_a) yield (dest_name + ".lora_a", lora_a)
yield (dest_name + ".lora_b", lora_b) yield (dest_name + ".lora_b", lora_b)
with open(lora_config, "r") as f:
lparams: dict[str, Any] = json.load(f)
alpha: float = lparams["lora_alpha"] alpha: float = lparams["lora_alpha"]
model_instance = LoraModel( model_instance = LoraModel(
@ -399,6 +425,7 @@ if __name__ == '__main__':
dry_run=args.dry_run, dry_run=args.dry_run,
dir_lora_model=dir_lora, dir_lora_model=dir_lora,
lora_alpha=alpha, lora_alpha=alpha,
hparams=hparams,
) )
logger.info("Exporting model...") logger.info("Exporting model...")