From 280ae720d7f74bed35aafdd8b32f691b4e806f67 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 23 Oct 2023 13:07:17 -0700
Subject: [PATCH] Organize

---
 modules/LoRA.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/modules/LoRA.py b/modules/LoRA.py
index 338213a3..1f1156cf 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -8,14 +8,6 @@ from modules.logging_colors import logger
 from modules.models import reload_model
 
 
-def merge_loras():
-    if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
-        logger.warning("The loaded LoRAs cannot be merged, as they have dissimilar ranks. Only the first one will be active.")
-        return
-
-    shared.model.add_weighted_adapter(shared.lora_names, [1] * len(shared.lora_names), "__merged")
-    shared.model.set_adapter("__merged")
-
 def add_lora_to_model(lora_names):
     if 'GPTQForCausalLM' in shared.model.__class__.__name__ or shared.args.loader == 'AutoGPTQ':
         add_lora_autogptq(lora_names)
@@ -189,3 +181,12 @@ def add_lora_transformers(lora_names):
             shared.model = shared.model.to(device)
         else:
             shared.model = shared.model.cuda()
+
+
+def merge_loras():
+    if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
+        logger.warning("The loaded LoRAs cannot be merged, as they have dissimilar ranks. Only the first one will be active.")
+        return
+
+    shared.model.add_weighted_adapter(shared.lora_names, [1] * len(shared.lora_names), "__merged")
+    shared.model.set_adapter("__merged")