diff --git a/docs/What Works.md b/docs/What Works.md
index e8a603c8..86936039 100644
--- a/docs/What Works.md
+++ b/docs/What Works.md
@@ -2,13 +2,13 @@
 
 | Loader         | Loading 1 LoRA | Loading 2 or more LoRAs | Training LoRAs | Multimodal extension | Perplexity evaluation |
 |----------------|----------------|-------------------------|----------------|----------------------|-----------------------|
-| Transformers   | ✅             | ❌                      | ✅*            | ✅                   | ✅                    |
+| Transformers   | ✅             | ✅                      | ✅*            | ✅                   | ✅                    |
 | ExLlama_HF     | ✅             | ❌                      | ❌             | ❌                   | ✅                    |
 | ExLlamav2_HF   | ✅             | ✅                      | ❌             | ❌                   | ✅                    |
 | ExLlama        | ✅             | ❌                      | ❌             | ❌                   | use ExLlama_HF        |
 | ExLlamav2      | ✅             | ✅                      | ❌             | ❌                   | use ExLlamav2_HF      |
 | AutoGPTQ       | ✅             | ❌                      | ❌             | ✅                   | ✅                    |
-| GPTQ-for-LLaMa | ✅**           | ❌                      | ✅             | ✅                   | ✅                    |
+| GPTQ-for-LLaMa | ✅**           | ✅                      | ✅             | ✅                   | ✅                    |
 | llama.cpp      | ❌             | ❌                      | ❌             | ❌                   | use llamacpp_HF       |
 | llamacpp_HF    | ❌             | ❌                      | ❌             | ❌                   | ✅                    |
 | ctransformers  | ❌             | ❌                      | ❌             | ❌                   | ❌                    |
diff --git a/modules/LoRA.py b/modules/LoRA.py
index b3997d80..338213a3 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -8,6 +8,14 @@ from modules.logging_colors import logger
 from modules.models import reload_model
 
 
+def merge_loras():
+    if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
+        logger.warning("The loaded LoRAs cannot be merged, as they have dissimilar ranks. Only the first one will be active.")
+        return
+
+    shared.model.add_weighted_adapter(shared.lora_names, [1] * len(shared.lora_names), "__merged")
+    shared.model.set_adapter("__merged")
+
 def add_lora_to_model(lora_names):
     if 'GPTQForCausalLM' in shared.model.__class__.__name__ or shared.args.loader == 'AutoGPTQ':
         add_lora_autogptq(lora_names)
@@ -136,11 +144,14 @@ def add_lora_transformers(lora_names):
         return
 
     # Add a LoRA when another LoRA is already present
-    if len(removed_set) == 0 and len(prior_set) > 0:
+    if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
         logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
         for lora in added_set:
             shared.model.load_adapter(get_lora_path(lora), lora)
 
+        if len(lora_names) > 1:
+            merge_loras()
+
         return
 
     # If any LoRA needs to be removed, start over
@@ -165,6 +176,9 @@ def add_lora_transformers(lora_names):
             for lora in lora_names[1:]:
                 shared.model.load_adapter(get_lora_path(lora), lora)
 
+        if len(lora_names) > 1:
+            merge_loras()
+
     shared.lora_names = lora_names
     if not shared.args.load_in_8bit and not shared.args.cpu: