From 11f38b5c2b3fb6feeebdafa102a0a88e5f647a8b Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 5 Jun 2023 23:29:29 -0300
Subject: [PATCH] Add AutoGPTQ LoRA support

---
 modules/LoRA.py | 94 ++++++++++++++++++++++++++++++++----------------
 1 file changed, 61 insertions(+), 33 deletions(-)

diff --git a/modules/LoRA.py b/modules/LoRA.py
index 56f90771..bbee9440 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -1,10 +1,13 @@
 from pathlib import Path
 
 import torch
+from auto_gptq import get_gptq_peft_model
+from auto_gptq.utils.peft_utils import GPTQLoraConfig
 from peft import PeftModel
 
 import modules.shared as shared
 from modules.logging_colors import logger
+from modules.models import reload_model
 
 
 def add_lora_to_model(lora_names):
@@ -13,43 +16,68 @@ def add_lora_to_model(lora_names):
     removed_set = prior_set - set(lora_names)
     shared.lora_names = list(lora_names)
 
-    # If no LoRA needs to be added or removed, exit
-    if len(added_set) == 0 and len(removed_set) == 0:
-        return
+    is_autogptq = 'GPTQForCausalLM' in shared.model.__class__.__name__
 
-    # Add a LoRA when another LoRA is already present
-    if len(removed_set) == 0 and len(prior_set) > 0:
-        logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
-        for lora in added_set:
-            shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
+    # AutoGPTQ case. It doesn't use the peft functions.
+    # Copied from https://github.com/Ph0rk0z/text-generation-webui-testing
+    if is_autogptq:
+        if len(prior_set) > 0:
+            reload_model()
 
-        return
+        if len(shared.lora_names) == 0:
+            return
+        else:
+            if len(shared.lora_names) > 1:
+                logger.warning('AutoGPTQ can only work with 1 LoRA at the moment. Only the first one in the list will be loaded')
 
-    # If any LoRA needs to be removed, start over
-    if len(removed_set) > 0:
-        shared.model.disable_adapter()
-        shared.model = shared.model.base_model.model
+            peft_config = GPTQLoraConfig(
+                inference_mode=True,
+            )
 
-    if len(lora_names) > 0:
-        logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
-        params = {}
-        if not shared.args.cpu:
-            params['dtype'] = shared.model.dtype
-            if hasattr(shared.model, "hf_device_map"):
-                params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
-            elif shared.args.load_in_8bit:
-                params['device_map'] = {'': 0}
+            lora_path = Path(f"{shared.args.lora_dir}/{shared.lora_names[0]}")
+            logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]])))
+            shared.model = get_gptq_peft_model(shared.model, peft_config, lora_path)
+            return
 
-        shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), **params)
+    # Transformers case
+    else:
+        # If no LoRA needs to be added or removed, exit
+        if len(added_set) == 0 and len(removed_set) == 0:
+            return
 
-        for lora in lora_names[1:]:
-            shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
+        # Add a LoRA when another LoRA is already present
+        if len(removed_set) == 0 and len(prior_set) > 0:
+            logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
+            for lora in added_set:
+                shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
 
-        if not shared.args.load_in_8bit and not shared.args.cpu:
-            shared.model.half()
-            if not hasattr(shared.model, "hf_device_map"):
-                if torch.has_mps:
-                    device = torch.device('mps')
-                    shared.model = shared.model.to(device)
-                else:
-                    shared.model = shared.model.cuda()
+            return
+
+        # If any LoRA needs to be removed, start over
+        if len(removed_set) > 0:
+            shared.model.disable_adapter()
+            shared.model = shared.model.base_model.model
+
+        if len(lora_names) > 0:
+            logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
+            params = {}
+            if not shared.args.cpu:
+                params['dtype'] = shared.model.dtype
+                if hasattr(shared.model, "hf_device_map"):
+                    params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
+                elif shared.args.load_in_8bit:
+                    params['device_map'] = {'': 0}
+
+            shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), **params)
+
+            for lora in lora_names[1:]:
+                shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
+
+            if not shared.args.load_in_8bit and not shared.args.cpu:
+                shared.model.half()
+                if not hasattr(shared.model, "hf_device_map"):
+                    if torch.has_mps:
+                        device = torch.device('mps')
+                        shared.model = shared.model.to(device)
+                    else:
+                        shared.model = shared.model.cuda()
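
Note (not part of the patch): a minimal sketch of how the new AutoGPTQ branch is exercised, assuming the web UI has already loaded a GPTQ-quantized model so that shared.model is an AutoGPTQ GPTQForCausalLM instance. The LoRA directory value and the adapter name "my-gptq-lora" below are hypothetical, not taken from the commit.

    import modules.shared as shared
    from modules.LoRA import add_lora_to_model

    # Assumes a GPTQ model was loaded first; 'GPTQForCausalLM' in the class
    # name routes execution into the AutoGPTQ branch above.
    shared.args.lora_dir = "loras"       # hypothetical LoRA directory
    add_lora_to_model(["my-gptq-lora"])  # AutoGPTQ branch applies only the first LoRA in the list

    # Passing an empty list reloads the base model, which removes the adapter:
    add_lora_to_model([])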