mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-24 13:28:59 +01:00
Add AutoGPTQ LoRA support
This commit is contained in:
parent
3a5cfe96f0
commit
11f38b5c2b
@ -1,10 +1,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from auto_gptq import get_gptq_peft_model
|
||||
from auto_gptq.utils.peft_utils import GPTQLoraConfig
|
||||
from peft import PeftModel
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.models import reload_model
|
||||
|
||||
|
||||
def add_lora_to_model(lora_names):
|
||||
@ -13,43 +16,68 @@ def add_lora_to_model(lora_names):
|
||||
removed_set = prior_set - set(lora_names)
|
||||
shared.lora_names = list(lora_names)
|
||||
|
||||
# If no LoRA needs to be added or removed, exit
|
||||
if len(added_set) == 0 and len(removed_set) == 0:
|
||||
return
|
||||
is_autogptq = 'GPTQForCausalLM' in shared.model.__class__.__name__
|
||||
|
||||
# Add a LoRA when another LoRA is already present
|
||||
if len(removed_set) == 0 and len(prior_set) > 0:
|
||||
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
||||
for lora in added_set:
|
||||
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
||||
# AutoGPTQ case. It doesn't use the peft functions.
|
||||
# Copied from https://github.com/Ph0rk0z/text-generation-webui-testing
|
||||
if is_autogptq:
|
||||
if len(prior_set) > 0:
|
||||
reload_model()
|
||||
|
||||
return
|
||||
if len(shared.lora_names) == 0:
|
||||
return
|
||||
else:
|
||||
if len(shared.lora_names) > 1:
|
||||
logger.warning('AutoGPTQ can only work with 1 LoRA at the moment. Only the first one in the list will be loaded')
|
||||
|
||||
# If any LoRA needs to be removed, start over
|
||||
if len(removed_set) > 0:
|
||||
shared.model.disable_adapter()
|
||||
shared.model = shared.model.base_model.model
|
||||
peft_config = GPTQLoraConfig(
|
||||
inference_mode=True,
|
||||
)
|
||||
|
||||
if len(lora_names) > 0:
|
||||
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
||||
params = {}
|
||||
if not shared.args.cpu:
|
||||
params['dtype'] = shared.model.dtype
|
||||
if hasattr(shared.model, "hf_device_map"):
|
||||
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
||||
elif shared.args.load_in_8bit:
|
||||
params['device_map'] = {'': 0}
|
||||
lora_path = Path(f"{shared.args.lora_dir}/{shared.lora_names[0]}")
|
||||
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]])))
|
||||
shared.model = get_gptq_peft_model(shared.model, peft_config, lora_path)
|
||||
return
|
||||
|
||||
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), **params)
|
||||
# Transformers case
|
||||
else:
|
||||
# If no LoRA needs to be added or removed, exit
|
||||
if len(added_set) == 0 and len(removed_set) == 0:
|
||||
return
|
||||
|
||||
for lora in lora_names[1:]:
|
||||
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
||||
# Add a LoRA when another LoRA is already present
|
||||
if len(removed_set) == 0 and len(prior_set) > 0:
|
||||
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
||||
for lora in added_set:
|
||||
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
||||
|
||||
if not shared.args.load_in_8bit and not shared.args.cpu:
|
||||
shared.model.half()
|
||||
if not hasattr(shared.model, "hf_device_map"):
|
||||
if torch.has_mps:
|
||||
device = torch.device('mps')
|
||||
shared.model = shared.model.to(device)
|
||||
else:
|
||||
shared.model = shared.model.cuda()
|
||||
return
|
||||
|
||||
# If any LoRA needs to be removed, start over
|
||||
if len(removed_set) > 0:
|
||||
shared.model.disable_adapter()
|
||||
shared.model = shared.model.base_model.model
|
||||
|
||||
if len(lora_names) > 0:
|
||||
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
||||
params = {}
|
||||
if not shared.args.cpu:
|
||||
params['dtype'] = shared.model.dtype
|
||||
if hasattr(shared.model, "hf_device_map"):
|
||||
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
||||
elif shared.args.load_in_8bit:
|
||||
params['device_map'] = {'': 0}
|
||||
|
||||
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), **params)
|
||||
|
||||
for lora in lora_names[1:]:
|
||||
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
||||
|
||||
if not shared.args.load_in_8bit and not shared.args.cpu:
|
||||
shared.model.half()
|
||||
if not hasattr(shared.model, "hf_device_map"):
|
||||
if torch.has_mps:
|
||||
device = torch.device('mps')
|
||||
shared.model = shared.model.to(device)
|
||||
else:
|
||||
shared.model = shared.model.cuda()
|
||||
|
Loading…
Reference in New Issue
Block a user