Mirror of https://github.com/oobabooga/text-generation-webui.git
transformers loader: multi-LoRAs support (#3120)
Commit d0c3b407b3 · Parent 4405513ca5
```diff
@@ -2,13 +2,13 @@
 
 | Loader | Loading 1 LoRA | Loading 2 or more LoRAs | Training LoRAs | Multimodal extension | Perplexity evaluation |
 |----------------|----------------|-------------------------|----------------|----------------------|-----------------------|
-| Transformers | ✅ | ❌ | ✅* | ✅ | ✅ |
+| Transformers | ✅ | ✅ | ✅* | ✅ | ✅ |
 | ExLlama_HF | ✅ | ❌ | ❌ | ❌ | ✅ |
 | ExLlamav2_HF | ✅ | ✅ | ❌ | ❌ | ✅ |
 | ExLlama | ✅ | ❌ | ❌ | ❌ | use ExLlama_HF |
 | ExLlamav2 | ✅ | ✅ | ❌ | ❌ | use ExLlamav2_HF |
 | AutoGPTQ | ✅ | ❌ | ❌ | ✅ | ✅ |
-| GPTQ-for-LLaMa | ✅** | ❌ | ✅ | ✅ | ✅ |
+| GPTQ-for-LLaMa | ✅** | ✅ | ✅ | ✅ | ✅ |
 | llama.cpp | ❌ | ❌ | ❌ | ❌ | use llamacpp_HF |
 | llamacpp_HF | ❌ | ❌ | ❌ | ❌ | ✅ |
 | ctransformers | ❌ | ❌ | ❌ | ❌ | ❌ |
```
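The two flipped cells mean the Transformers loader (and GPTQ-for-LLaMa, which goes through the same transformers/PEFT code path) can now attach more than one LoRA at a time. For context, attaching multiple adapters with PEFT looks roughly like the sketch below; the model name and adapter paths are placeholders, not part of this commit.

```python
# Rough sketch of stacking two LoRAs on a transformers model with PEFT.
# "base-model" and the "loras/..." paths are hypothetical placeholders.
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("base-model")
# The first adapter wraps the base model in a PeftModel...
model = PeftModel.from_pretrained(model, "loras/first-lora", adapter_name="first-lora")
# ...and further adapters are attached to that existing wrapper.
model.load_adapter("loras/second-lora", adapter_name="second-lora")
```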
```diff
@@ -8,6 +8,14 @@ from modules.logging_colors import logger
 from modules.models import reload_model
 
 
+def merge_loras():
+    if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
+        logger.warning("The loaded LoRAs cannot be merged, as they have dissimilar ranks. Only the first one will be active.")
+        return
+
+    shared.model.add_weighted_adapter(shared.lora_names, [1] * len(shared.lora_names), "__merged")
+    shared.model.set_adapter("__merged")
+
 def add_lora_to_model(lora_names):
     if 'GPTQForCausalLM' in shared.model.__class__.__name__ or shared.args.loader == 'AutoGPTQ':
         add_lora_autogptq(lora_names)
```
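The new `merge_loras()` helper leans on two PEFT APIs: `add_weighted_adapter`, which combines already-loaded adapters into a new named adapter, and `set_adapter`, which makes that adapter the active one. Merging requires all adapters to share the same rank `r`, which is what the set comprehension over `peft_config` checks before attempting it. A minimal standalone sketch, assuming a `peft.PeftModel` named `model` with two equal-rank adapters "a" and "b" already loaded:

```python
# Minimal sketch of the merge performed by merge_loras(), assuming a
# peft.PeftModel `model` with equal-rank adapters "a" and "b" loaded.
ranks = {model.peft_config[name].r for name in model.peft_config}
if len(ranks) > 1:
    # Mirrors the diff's guard: mixed ranks cannot be combined.
    print("Cannot merge adapters with dissimilar ranks")
else:
    # Combine both adapters with weight 1 each into a new adapter...
    model.add_weighted_adapter(["a", "b"], [1, 1], "__merged")
    # ...and route generation through the merged result.
    model.set_adapter("__merged")
```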
```diff
@@ -136,11 +144,14 @@ def add_lora_transformers(lora_names):
         return
 
     # Add a LoRA when another LoRA is already present
-    if len(removed_set) == 0 and len(prior_set) > 0:
+    if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
         logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
         for lora in added_set:
             shared.model.load_adapter(get_lora_path(lora), lora)
 
+        if len(lora_names) > 1:
+            merge_loras()
+
         return
 
     # If any LoRA needs to be removed, start over
```
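Note the extra condition on this branch: once `"__merged"` exists in `peft_config`, the hot-add path is skipped, because the merged adapter is a frozen snapshot of the previously combined LoRAs and would go stale if another adapter were simply attached beside it. Failing the check sends execution to the start-over path below, which reloads every adapter and re-merges. Schematically (an illustrative restatement, not code from the diff):

```python
# Illustrative restatement of the updated branch condition.
can_hot_add = (
    len(removed_set) == 0
    and len(prior_set) > 0
    and "__merged" not in shared.model.peft_config
)
# When False, add_lora_transformers() falls through to the path that
# reloads all adapters from scratch and calls merge_loras() again.
```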
```diff
@@ -165,6 +176,9 @@ def add_lora_transformers(lora_names):
         for lora in lora_names[1:]:
             shared.model.load_adapter(get_lora_path(lora), lora)
 
+        if len(lora_names) > 1:
+            merge_loras()
+
     shared.lora_names = lora_names
 
     if not shared.args.load_in_8bit and not shared.args.cpu:
```
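Taken together, the Transformers path now attaches every requested adapter and then merges whenever more than one was requested. A hedged usage sketch, assuming the entry point is imported from `modules/LoRA.py` as elsewhere in this repo; the adapter names are placeholders:

```python
# Hypothetical call site after a model has been loaded with the
# Transformers loader; the adapter directory names are placeholders.
from modules.LoRA import add_lora_to_model

# Attaches both LoRAs, then merge_loras() combines them into a single
# "__merged" adapter and activates it (provided the ranks match).
add_lora_to_model(["alpaca-lora", "wizard-lora"])
```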