mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Fix PEFT LoRA unloading
This commit is contained in:
parent
a290d17386
commit
a6f1e1bcc5
@ -149,10 +149,7 @@ def add_lora_transformers(lora_names):
|
|||||||
|
|
||||||
# If any LoRA needs to be removed, start over
|
# If any LoRA needs to be removed, start over
|
||||||
if len(removed_set) > 0:
|
if len(removed_set) > 0:
|
||||||
# shared.model may no longer be PeftModel
|
shared.model = shared.model.unload()
|
||||||
if hasattr(shared.model, 'disable_adapter'):
|
|
||||||
shared.model.disable_adapter()
|
|
||||||
shared.model = shared.model.base_model.model
|
|
||||||
|
|
||||||
if len(lora_names) > 0:
|
if len(lora_names) > 0:
|
||||||
params = {}
|
params = {}
|
||||||
@ -172,8 +169,6 @@ def add_lora_transformers(lora_names):
|
|||||||
if len(lora_names) > 1:
|
if len(lora_names) > 1:
|
||||||
merge_loras()
|
merge_loras()
|
||||||
|
|
||||||
shared.lora_names = lora_names
|
|
||||||
|
|
||||||
if not shared.args.load_in_8bit and not shared.args.cpu:
|
if not shared.args.load_in_8bit and not shared.args.cpu:
|
||||||
shared.model.half()
|
shared.model.half()
|
||||||
if not hasattr(shared.model, "hf_device_map"):
|
if not hasattr(shared.model, "hf_device_map"):
|
||||||
@ -186,6 +181,8 @@ def add_lora_transformers(lora_names):
|
|||||||
else:
|
else:
|
||||||
shared.model = shared.model.cuda()
|
shared.model = shared.model.cuda()
|
||||||
|
|
||||||
|
shared.lora_names = lora_names
|
||||||
|
|
||||||
|
|
||||||
def merge_loras():
|
def merge_loras():
|
||||||
if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
|
if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
|
||||||
|
Loading…
Reference in New Issue
Block a user