mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Fixed the param name when loading a LoRA using a model loaded in 4 or 8 bits (#3036)
This commit is contained in:
parent
1f540fa4f8
commit
d7e14e1f78
@ -114,11 +114,12 @@ def add_lora_transformers(lora_names):
|
||||
if len(lora_names) > 0:
|
||||
params = {}
|
||||
if not shared.args.cpu:
|
||||
params['dtype'] = shared.model.dtype
|
||||
if hasattr(shared.model, "hf_device_map"):
|
||||
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
||||
elif shared.args.load_in_8bit:
|
||||
params['device_map'] = {'': 0}
|
||||
if shared.args.load_in_4bit or shared.args.load_in_8bit:
|
||||
params['peft_type'] = shared.model.dtype
|
||||
else:
|
||||
params['dtype'] = shared.model.dtype
|
||||
if hasattr(shared.model, "hf_device_map"):
|
||||
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
||||
|
||||
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
||||
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), adapter_name=lora_names[0], **params)
|
||||
|
Loading…
Reference in New Issue
Block a user