mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 10:59:32 +01:00
Implement a demo HF wrapper for exllama to utilize existing HF transformers decoding. (#2777)
This commit is contained in:
parent
a06acd6d09
commit
580c1ee748
@ -212,7 +212,7 @@ Optionally, you can use the following command-line flags:
|
||||
|
||||
| Flag | Description |
|
||||
|--------------------------------------------|-------------|
|
||||
| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen |
|
||||
| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, flexgen |
|
||||
|
||||
#### Accelerate/transformers
|
||||
|
||||
|
82
modules/exllama_hf.py
Normal file
82
modules/exllama_hf.py
Normal file
@ -0,0 +1,82 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import *
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
GenerationConfig,
|
||||
LlamaTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel
|
||||
)
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.relative_imports import RelativeImport
|
||||
|
||||
with RelativeImport("repositories/exllama"):
|
||||
from model import ExLlama, ExLlamaCache, ExLlamaConfig
|
||||
|
||||
|
||||
class ExllamaHF(PreTrainedModel):
|
||||
def __init__(self, config: ExLlamaConfig):
|
||||
super().__init__(PretrainedConfig())
|
||||
self.ex_config = config
|
||||
self.ex_model = ExLlama(self.ex_config)
|
||||
self.generation_config = GenerationConfig()
|
||||
|
||||
def _validate_model_class(self):
|
||||
pass
|
||||
|
||||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
||||
pass
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||
return {'input_ids': input_ids, **kwargs}
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
# TODO: May cause problem on multi-gpu inference?
|
||||
return torch.device(0)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
|
||||
assert len(args) == 0, 'no *args should be passed to forward'
|
||||
use_cache = kwargs['use_cache']
|
||||
seq = kwargs['input_ids'][0].tolist()
|
||||
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
|
||||
if cache is None:
|
||||
cache = ExLlamaCache(self.ex_model)
|
||||
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
|
||||
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(self.device)
|
||||
return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
|
||||
if isinstance(pretrained_model_name_or_path, str):
|
||||
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
|
||||
|
||||
pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
|
||||
config = ExLlamaConfig(pretrained_model_name_or_path / 'config.json')
|
||||
|
||||
# from 'oobabooga/text-generation-webui/modules/exllama.py'
|
||||
weight_path = None
|
||||
for ext in ['.safetensors', '.pt', '.bin']:
|
||||
found = list(pretrained_model_name_or_path.glob(f"*{ext}"))
|
||||
if len(found) > 0:
|
||||
weight_path = found[-1]
|
||||
break
|
||||
assert weight_path is not None, f'could not find weight in "{pretrained_model_name_or_path}"'
|
||||
|
||||
config.model_path = str(weight_path)
|
||||
|
||||
# This slowes down a bit but align better with autogptq generation.
|
||||
# TODO: Should give user choice to tune the exllama config
|
||||
config.act_order = True
|
||||
config.fused_attn = False
|
||||
config.fused_mlp_thd = 0
|
||||
|
||||
return ExllamaHF(config)
|
@ -55,6 +55,10 @@ loaders_and_params = {
|
||||
'ExLlama' : [
|
||||
'gpu_split',
|
||||
'exllama_info',
|
||||
],
|
||||
'ExLlama_HF' : [
|
||||
'gpu_split',
|
||||
'exllama_HF_info',
|
||||
]
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,8 @@ def load_model(model_name, loader=None):
|
||||
'llama.cpp': llamacpp_loader,
|
||||
'FlexGen': flexgen_loader,
|
||||
'RWKV': RWKV_loader,
|
||||
'ExLlama': ExLlama_loader
|
||||
'ExLlama': ExLlama_loader,
|
||||
'ExLlama_HF': ExLlama_HF_loader
|
||||
}
|
||||
|
||||
if loader is None:
|
||||
@ -278,6 +279,12 @@ def ExLlama_loader(model_name):
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def ExLlama_HF_loader(model_name):
|
||||
from modules.exllama_hf import ExllamaHF
|
||||
|
||||
return ExllamaHF.from_pretrained(model_name)
|
||||
|
||||
|
||||
def get_max_memory_dict():
|
||||
max_memory = {}
|
||||
if shared.args.gpu_memory:
|
||||
|
@ -98,7 +98,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
|
||||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||
|
||||
# Model loader
|
||||
parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen')
|
||||
parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, flexgen')
|
||||
|
||||
# Accelerate/transformers
|
||||
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
|
||||
@ -218,6 +218,8 @@ def fix_loader_name(name):
|
||||
return 'GPTQ-for-LLaMa'
|
||||
elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
|
||||
return 'ExLlama'
|
||||
elif name in ['exllama-hf', 'exllama_hf', 'exllama hf', 'ex-llama-hf', 'ex_llama_hf']:
|
||||
return 'ExLlama_HF'
|
||||
|
||||
|
||||
if args.loader is not None:
|
||||
|
@ -104,9 +104,8 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i
|
||||
else:
|
||||
new_tokens = len(output_ids) - len(input_ids[0])
|
||||
reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
|
||||
|
||||
# Prevent LlamaTokenizer from skipping a space
|
||||
if type(shared.tokenizer) is transformers.LlamaTokenizer and len(output_ids) > 0:
|
||||
if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0:
|
||||
if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith('▁'):
|
||||
reply = ' ' + reply
|
||||
|
||||
|
@ -197,7 +197,7 @@ def create_model_menus():
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama", "llama.cpp"], value=None)
|
||||
shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama", "ExLlama_HF", "llama.cpp"], value=None)
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@ -237,6 +237,7 @@ def create_model_menus():
|
||||
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
|
||||
shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
|
||||
shared.gradio['exllama_info'] = gr.Markdown('ExLlama has to be installed manually. See the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
|
||||
shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s still a bit buggy, so feel free to help out by fixing issues.\n\nCheck out PR [#2777](https://github.com/oobabooga/text-generation-webui/pull/2777) for more details.')
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
|
Loading…
Reference in New Issue
Block a user