mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-23 00:18:20 +01:00
Add LoRA support to ExLlama_HF
This commit is contained in:
parent
b7c627f9a0
commit
22d455b072
@ -11,7 +11,7 @@ from modules.models import reload_model
|
|||||||
def add_lora_to_model(lora_names):
|
def add_lora_to_model(lora_names):
|
||||||
if 'GPTQForCausalLM' in shared.model.__class__.__name__:
|
if 'GPTQForCausalLM' in shared.model.__class__.__name__:
|
||||||
add_lora_autogptq(lora_names)
|
add_lora_autogptq(lora_names)
|
||||||
elif shared.model.__class__.__name__ == 'ExllamaModel':
|
elif shared.model.__class__.__name__ in ['ExllamaModel', 'ExllamaHF']:
|
||||||
add_lora_exllama(lora_names)
|
add_lora_exllama(lora_names)
|
||||||
else:
|
else:
|
||||||
add_lora_transformers(lora_names)
|
add_lora_transformers(lora_names)
|
||||||
@ -29,7 +29,11 @@ def add_lora_exllama(lora_names):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if len(lora_names) == 0:
|
if len(lora_names) == 0:
|
||||||
|
if shared.model.__class__.__name__ == 'ExllamaModel':
|
||||||
shared.model.generator.lora = None
|
shared.model.generator.lora = None
|
||||||
|
else:
|
||||||
|
shared.model.lora = None
|
||||||
|
|
||||||
shared.lora_names = []
|
shared.lora_names = []
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
@ -41,8 +45,13 @@ def add_lora_exllama(lora_names):
|
|||||||
lora_adapter_path = lora_path / "adapter_model.bin"
|
lora_adapter_path = lora_path / "adapter_model.bin"
|
||||||
|
|
||||||
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]])))
|
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]])))
|
||||||
|
if shared.model.__class__.__name__ == 'ExllamaModel':
|
||||||
lora = ExLlamaLora(shared.model.model, str(lora_config_path), str(lora_adapter_path))
|
lora = ExLlamaLora(shared.model.model, str(lora_config_path), str(lora_adapter_path))
|
||||||
shared.model.generator.lora = lora
|
shared.model.generator.lora = lora
|
||||||
|
else:
|
||||||
|
lora = ExLlamaLora(shared.model.ex_model, str(lora_config_path), str(lora_adapter_path))
|
||||||
|
shared.model.lora = lora
|
||||||
|
|
||||||
shared.lora_names = [lora_names[0]]
|
shared.lora_names = [lora_names[0]]
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ class ExllamaHF(PreTrainedModel):
|
|||||||
self.ex_config = config
|
self.ex_config = config
|
||||||
self.ex_model = ExLlama(self.ex_config)
|
self.ex_model = ExLlama(self.ex_config)
|
||||||
self.generation_config = GenerationConfig()
|
self.generation_config = GenerationConfig()
|
||||||
|
self.lora = None
|
||||||
|
|
||||||
def _validate_model_class(self):
|
def _validate_model_class(self):
|
||||||
pass
|
pass
|
||||||
@ -53,9 +54,9 @@ class ExllamaHF(PreTrainedModel):
|
|||||||
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
|
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = ExLlamaCache(self.ex_model)
|
cache = ExLlamaCache(self.ex_model)
|
||||||
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
|
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
|
||||||
|
|
||||||
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
|
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
@ -245,7 +245,7 @@ def create_model_menus():
|
|||||||
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
|
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
|
||||||
shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
|
shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
|
||||||
shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
|
shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
|
||||||
shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama and doesn\'t support LoRA.')
|
shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
Loading…
Reference in New Issue
Block a user