mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-24 02:29:25 +01:00
transformers loader: multi-LoRAs support (#3120)
This commit is contained in:
parent
4405513ca5
commit
d0c3b407b3
@ -2,13 +2,13 @@
|
||||
|
||||
| Loader | Loading 1 LoRA | Loading 2 or more LoRAs | Training LoRAs | Multimodal extension | Perplexity evaluation |
|
||||
|----------------|----------------|-------------------------|----------------|----------------------|-----------------------|
|
||||
| Transformers | ✅ | ❌ | ✅* | ✅ | ✅ |
|
||||
| Transformers | ✅ | ✅ | ✅* | ✅ | ✅ |
|
||||
| ExLlama_HF | ✅ | ❌ | ❌ | ❌ | ✅ |
|
||||
| ExLlamav2_HF | ✅ | ✅ | ❌ | ❌ | ✅ |
|
||||
| ExLlama | ✅ | ❌ | ❌ | ❌ | use ExLlama_HF |
|
||||
| ExLlamav2 | ✅ | ✅ | ❌ | ❌ | use ExLlamav2_HF |
|
||||
| AutoGPTQ | ✅ | ❌ | ❌ | ✅ | ✅ |
|
||||
| GPTQ-for-LLaMa | ✅** | ❌ | ✅ | ✅ | ✅ |
|
||||
| GPTQ-for-LLaMa | ✅** | ✅ | ✅ | ✅ | ✅ |
|
||||
| llama.cpp | ❌ | ❌ | ❌ | ❌ | use llamacpp_HF |
|
||||
| llamacpp_HF | ❌ | ❌ | ❌ | ❌ | ✅ |
|
||||
| ctransformers | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
@ -8,6 +8,14 @@ from modules.logging_colors import logger
|
||||
from modules.models import reload_model
|
||||
|
||||
|
||||
def merge_loras():
|
||||
if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
|
||||
logger.warning("The loaded LoRAs cannot be merged, as they have dissimilar ranks. Only the first one will be active.")
|
||||
return
|
||||
|
||||
shared.model.add_weighted_adapter(shared.lora_names, [1] * len(shared.lora_names), "__merged")
|
||||
shared.model.set_adapter("__merged")
|
||||
|
||||
def add_lora_to_model(lora_names):
|
||||
if 'GPTQForCausalLM' in shared.model.__class__.__name__ or shared.args.loader == 'AutoGPTQ':
|
||||
add_lora_autogptq(lora_names)
|
||||
@ -136,11 +144,14 @@ def add_lora_transformers(lora_names):
|
||||
return
|
||||
|
||||
# Add a LoRA when another LoRA is already present
|
||||
if len(removed_set) == 0 and len(prior_set) > 0:
|
||||
if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
|
||||
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
||||
for lora in added_set:
|
||||
shared.model.load_adapter(get_lora_path(lora), lora)
|
||||
|
||||
if len(lora_names) > 1:
|
||||
merge_loras()
|
||||
|
||||
return
|
||||
|
||||
# If any LoRA needs to be removed, start over
|
||||
@ -165,6 +176,9 @@ def add_lora_transformers(lora_names):
|
||||
for lora in lora_names[1:]:
|
||||
shared.model.load_adapter(get_lora_path(lora), lora)
|
||||
|
||||
if len(lora_names) > 1:
|
||||
merge_loras()
|
||||
|
||||
shared.lora_names = lora_names
|
||||
|
||||
if not shared.args.load_in_8bit and not shared.args.cpu:
|
||||
|
Loading…
Reference in New Issue
Block a user