From a6f1e1bcc51ce9b0db62f095f23d329166a6ce9a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 Nov 2023 07:55:25 -0800
Subject: [PATCH] Fix PEFT LoRA unloading

---
 modules/LoRA.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/modules/LoRA.py b/modules/LoRA.py
index 4b119994..9c6edbf3 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -149,10 +149,7 @@ def add_lora_transformers(lora_names):
 
     # If any LoRA needs to be removed, start over
     if len(removed_set) > 0:
-        # shared.model may no longer be PeftModel
-        if hasattr(shared.model, 'disable_adapter'):
-            shared.model.disable_adapter()
-            shared.model = shared.model.base_model.model
+        shared.model = shared.model.unload()
 
     if len(lora_names) > 0:
         params = {}
@@ -172,8 +169,6 @@ def add_lora_transformers(lora_names):
         if len(lora_names) > 1:
             merge_loras()
 
-        shared.lora_names = lora_names
-
         if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()
             if not hasattr(shared.model, "hf_device_map"):
@@ -186,6 +181,8 @@ def add_lora_transformers(lora_names):
                 else:
                     shared.model = shared.model.cuda()
 
+    shared.lora_names = lora_names
+
 
 def merge_loras():
     if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
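
Note (not part of the patch): disable_adapter() (a context manager in recent
PEFT releases) leaves the LoRA modules attached to the model, whereas
PeftModel.unload() detaches them and returns the bare base model, which is what
the removal path above needs before re-wrapping with a new set of LoRAs. A
minimal sketch of that behavior follows; the model name and adapter path are
placeholders, not values from this repository.

    # Placeholder names: any transformers causal LM plus a local LoRA directory.
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    model = PeftModel.from_pretrained(base, "path/to/lora")  # wrap the base model with a LoRA

    # unload() strips the adapter modules and hands back the plain base model,
    # so it can be re-wrapped cleanly with a different set of adapters later.
    model = model.unload()
    # model is now an ordinary transformers model again, no longer a PeftModel.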