mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-30 06:00:15 +01:00
Fixed the param name when loading a LoRA using a model loaded in 4 or 8 bits (#3036)
This commit is contained in:
parent
1f540fa4f8
commit
d7e14e1f78
@ -114,11 +114,12 @@ def add_lora_transformers(lora_names):
|
|||||||
if len(lora_names) > 0:
|
if len(lora_names) > 0:
|
||||||
params = {}
|
params = {}
|
||||||
if not shared.args.cpu:
|
if not shared.args.cpu:
|
||||||
params['dtype'] = shared.model.dtype
|
if shared.args.load_in_4bit or shared.args.load_in_8bit:
|
||||||
if hasattr(shared.model, "hf_device_map"):
|
params['peft_type'] = shared.model.dtype
|
||||||
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
else:
|
||||||
elif shared.args.load_in_8bit:
|
params['dtype'] = shared.model.dtype
|
||||||
params['device_map'] = {'': 0}
|
if hasattr(shared.model, "hf_device_map"):
|
||||||
|
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
||||||
|
|
||||||
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
||||||
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), adapter_name=lora_names[0], **params)
|
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), adapter_name=lora_names[0], **params)
|
||||||
|
Loading…
Reference in New Issue
Block a user