mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-29 21:50:16 +01:00
Add ExLlama support (#2444)
This commit is contained in:
parent
dea43685b0
commit
9f40032d32
@ -18,7 +18,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
||||
## Features
|
||||
|
||||
* 3 interface modes: default, notebook, and chat
|
||||
* Multiple model backends: tranformers, llama.cpp, AutoGPTQ, GPTQ-for-LLaMa, RWKV, FlexGen
|
||||
* Multiple model backends: tranformers, llama.cpp, AutoGPTQ, GPTQ-for-LLaMa, ExLlama, RWKV, FlexGen
|
||||
* Dropdown menu for quickly switching between different models
|
||||
* LoRA: load and unload LoRAs on the fly, load multiple LoRAs at the same time, train a new LoRA
|
||||
* Precise instruction templates for chat mode, including Alpaca, Vicuna, Open Assistant, Dolly, Koala, ChatGLM, MOSS, RWKV-Raven, Galactica, StableLM, WizardLM, Baize, Ziya, Chinese-Vicuna, MPT, INCITE, Wizard Mega, KoAlpaca, Vigogne, Bactrian, h2o, and OpenBuddy
|
||||
@ -215,7 +215,7 @@ Optionally, you can use the following command-line flags:
|
||||
|
||||
| Flag | Description |
|
||||
|--------------------------------------------|-------------|
|
||||
| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: autogptq, gptq-for-llama, transformers, llamacpp, rwkv, flexgen |
|
||||
| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen |
|
||||
|
||||
#### Accelerate/transformers
|
||||
|
||||
|
16
docs/ExLlama.md
Normal file
16
docs/ExLlama.md
Normal file
@ -0,0 +1,16 @@
|
||||
# ExLlama
|
||||
|
||||
## About
|
||||
|
||||
ExLlama is an extremely optimized GPTQ backend for LLaMA models. It features much lower VRAM usage and much higher speeds due to not relying on unoptimized transformers code.
|
||||
|
||||
# Installation:
|
||||
|
||||
1) Clone the ExLlama repository into your `repositories` folder:
|
||||
|
||||
```
|
||||
cd repositories
|
||||
git clone https://github.com/turboderp/exllama
|
||||
```
|
||||
|
||||
2) Follow the remaining set up instructions in the official README: https://github.com/turboderp/exllama#exllama
|
@ -38,31 +38,31 @@ class RWKVModel:
|
||||
result.cached_output_logits = None
|
||||
return result
|
||||
|
||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=None, token_stop=None, callback=None):
|
||||
def generate(self, prompt, state, callback=None):
|
||||
args = PIPELINE_ARGS(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
alpha_frequency=alpha_frequency, # Frequency Penalty (as in GPT-3)
|
||||
alpha_presence=alpha_presence, # Presence Penalty (as in GPT-3)
|
||||
token_ban=token_ban or [0], # ban the generation of some tokens
|
||||
token_stop=token_stop or []
|
||||
temperature=state['temperature'],
|
||||
top_p=state['top_p'],
|
||||
top_k=state['top_k'],
|
||||
alpha_frequency=0.1, # Frequency Penalty (as in GPT-3)
|
||||
alpha_presence=0.1, # Presence Penalty (as in GPT-3)
|
||||
token_ban=[0], # ban the generation of some tokens
|
||||
token_stop=[]
|
||||
)
|
||||
|
||||
if self.cached_context != "":
|
||||
if context.startswith(self.cached_context):
|
||||
context = context[len(self.cached_context):]
|
||||
if prompt.startswith(self.cached_context):
|
||||
prompt = prompt[len(self.cached_context):]
|
||||
else:
|
||||
self.cached_context = ""
|
||||
self.cached_model_state = None
|
||||
self.cached_output_logits = None
|
||||
|
||||
# out = self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
|
||||
out = self.generate_from_cached_state(context, token_count=token_count, args=args, callback=callback)
|
||||
# out = self.pipeline.generate(prompt, token_count=state['max_new_tokens'], args=args, callback=callback)
|
||||
out = self.generate_from_cached_state(prompt, token_count=state['max_new_tokens'], args=args, callback=callback)
|
||||
return out
|
||||
|
||||
def generate_with_streaming(self, **kwargs):
|
||||
with Iteratorize(self.generate, kwargs, callback=None) as generator:
|
||||
def generate_with_streaming(self, *args, **kwargs):
|
||||
with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
|
||||
reply = ''
|
||||
for token in generator:
|
||||
reply += token
|
||||
@ -81,6 +81,7 @@ class RWKVModel:
|
||||
if ctx == "":
|
||||
out = self.cached_output_logits
|
||||
|
||||
token = None
|
||||
for i in range(token_count):
|
||||
# forward
|
||||
tokens = self.pipeline.encode(ctx) if i == 0 else [token]
|
||||
|
@ -55,11 +55,12 @@ class Iteratorize:
|
||||
Adapted from: https://stackoverflow.com/a/9969000
|
||||
"""
|
||||
|
||||
def __init__(self, func, kwargs=None, callback=None):
|
||||
def __init__(self, func, args=None, kwargs=None, callback=None):
|
||||
self.mfunc = func
|
||||
self.c_callback = callback
|
||||
self.q = Queue()
|
||||
self.sentinel = object()
|
||||
self.args = args or []
|
||||
self.kwargs = kwargs or {}
|
||||
self.stop_now = False
|
||||
|
||||
@ -70,7 +71,7 @@ class Iteratorize:
|
||||
|
||||
def gentask():
|
||||
try:
|
||||
ret = self.mfunc(callback=_callback, **self.kwargs)
|
||||
ret = self.mfunc(callback=_callback, *args, **self.kwargs)
|
||||
except ValueError:
|
||||
pass
|
||||
except:
|
||||
|
81
modules/exllama.py
Normal file
81
modules/exllama.py
Normal file
@ -0,0 +1,81 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path("repositories/exllama")))
|
||||
|
||||
from modules.logging_colors import logger
|
||||
from repositories.exllama.generator import ExLlamaGenerator
|
||||
from repositories.exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
|
||||
from repositories.exllama.tokenizer import ExLlamaTokenizer
|
||||
|
||||
|
||||
class ExllamaModel:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, path_to_model):
|
||||
|
||||
path_to_model = Path("models") / Path(path_to_model)
|
||||
tokenizer_model_path = path_to_model / "tokenizer.model"
|
||||
model_config_path = path_to_model / "config.json"
|
||||
|
||||
# Find the model checkpoint
|
||||
model_path = None
|
||||
for ext in ['.safetensors', '.pt', '.bin']:
|
||||
found = list(path_to_model.glob(f"*{ext}"))
|
||||
if len(found) > 0:
|
||||
if len(found) > 1:
|
||||
logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
|
||||
|
||||
model_path = found[-1]
|
||||
break
|
||||
|
||||
config = ExLlamaConfig(str(model_config_path))
|
||||
config.model_path = str(model_path)
|
||||
model = ExLlama(config)
|
||||
tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
|
||||
cache = ExLlamaCache(model)
|
||||
|
||||
result = self()
|
||||
result.config = config
|
||||
result.model = model
|
||||
result.cache = cache
|
||||
result.tokenizer = tokenizer
|
||||
return result, result
|
||||
|
||||
def generate(self, prompt, state, callback=None):
|
||||
generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
|
||||
generator.settings.temperature = state['temperature']
|
||||
generator.settings.top_p = state['top_p']
|
||||
generator.settings.top_k = state['top_k']
|
||||
generator.settings.typical = state['typical_p']
|
||||
generator.settings.token_repetition_penalty_max = state['repetition_penalty']
|
||||
if state['ban_eos_token']:
|
||||
generator.disallow_tokens([self.tokenizer.eos_token_id])
|
||||
|
||||
text = generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
|
||||
return text
|
||||
|
||||
def generate_with_streaming(self, prompt, state, callback=None):
|
||||
generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
|
||||
generator.settings.temperature = state['temperature']
|
||||
generator.settings.top_p = state['top_p']
|
||||
generator.settings.top_k = state['top_k']
|
||||
generator.settings.typical = state['typical_p']
|
||||
generator.settings.token_repetition_penalty_max = state['repetition_penalty']
|
||||
if state['ban_eos_token']:
|
||||
generator.disallow_tokens([self.tokenizer.eos_token_id])
|
||||
|
||||
generator.end_beam_search()
|
||||
ids = generator.tokenizer.encode(prompt)
|
||||
generator.gen_begin(ids)
|
||||
initial_len = generator.sequence[0].shape[0]
|
||||
for i in range(state['max_new_tokens']):
|
||||
token = generator.gen_single_token()
|
||||
yield (generator.tokenizer.decode(generator.sequence[0][initial_len:]))
|
||||
if token.item() == generator.tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
return self.tokenizer.encode(string)
|
@ -59,18 +59,18 @@ class LlamaCppModel:
|
||||
|
||||
return self.model.tokenize(string)
|
||||
|
||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
|
||||
context = context if type(context) is str else context.decode()
|
||||
def generate(self, prompt, state, callback=None):
|
||||
prompt = prompt if type(prompt) is str else prompt.decode()
|
||||
completion_chunks = self.model.create_completion(
|
||||
prompt=context,
|
||||
max_tokens=token_count,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
repeat_penalty=repetition_penalty,
|
||||
mirostat_mode=int(mirostat_mode),
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
prompt=prompt,
|
||||
max_tokens=state['max_new_tokens'],
|
||||
temperature=state['temperature'],
|
||||
top_p=state['top_p'],
|
||||
top_k=state['top_k'],
|
||||
repeat_penalty=state['repetition_penalty'],
|
||||
mirostat_mode=int(state['mirostat_mode']),
|
||||
mirostat_tau=state['mirostat_tau'],
|
||||
mirostat_eta=state['mirostat_eta'],
|
||||
stream=True
|
||||
)
|
||||
|
||||
@ -83,8 +83,8 @@ class LlamaCppModel:
|
||||
|
||||
return output
|
||||
|
||||
def generate_with_streaming(self, **kwargs):
|
||||
with Iteratorize(self.generate, kwargs, callback=None) as generator:
|
||||
def generate_with_streaming(self, *args, **kwargs):
|
||||
with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
|
||||
reply = ''
|
||||
for token in generator:
|
||||
reply += token
|
||||
|
@ -52,6 +52,9 @@ loaders_and_params = {
|
||||
'trust_remote_code',
|
||||
'transformers_info'
|
||||
],
|
||||
'ExLlama' : [
|
||||
'exllama_info',
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
|
@ -48,7 +48,8 @@ def load_model(model_name, loader=None):
|
||||
'GPTQ-for-LLaMa': GPTQ_loader,
|
||||
'llama.cpp': llamacpp_loader,
|
||||
'FlexGen': flexgen_loader,
|
||||
'RWKV': RWKV_loader
|
||||
'RWKV': RWKV_loader,
|
||||
'ExLlama': ExLlama_loader
|
||||
}
|
||||
|
||||
if loader is None:
|
||||
@ -270,6 +271,13 @@ def AutoGPTQ_loader(model_name):
|
||||
return modules.AutoGPTQ_loader.load_quantized(model_name)
|
||||
|
||||
|
||||
def ExLlama_loader(model_name):
|
||||
from modules.exllama import ExllamaModel
|
||||
|
||||
model, tokenizer = ExllamaModel.from_pretrained(model_name)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_max_memory_dict():
|
||||
max_memory = {}
|
||||
if shared.args.gpu_memory:
|
||||
|
@ -94,7 +94,7 @@ def apply_model_settings_to_state(model, state):
|
||||
loader = 'AutoGPTQ'
|
||||
|
||||
# If the user is using an alternative GPTQ loader, let them keep using it
|
||||
if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'exllama']):
|
||||
if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama']):
|
||||
state['loader'] = loader
|
||||
|
||||
for k in model_settings:
|
||||
|
@ -97,7 +97,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
|
||||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||
|
||||
# Model loader
|
||||
parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: autogptq, gptq-for-llama, transformers, llamacpp, rwkv, flexgen')
|
||||
parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen')
|
||||
|
||||
# Accelerate/transformers
|
||||
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
|
||||
@ -212,6 +212,8 @@ def fix_loader_name(name):
|
||||
return 'AutoGPTQ'
|
||||
elif name in ['gptq-for-llama', 'gptqforllama', 'gptqllama', 'gptq for llama', 'gptq_for_llama']:
|
||||
return 'GPTQ-for-LLaMa'
|
||||
elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
|
||||
return 'ExLlama'
|
||||
|
||||
|
||||
if args.loader is not None:
|
||||
|
@ -51,7 +51,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
||||
if truncation_length is not None:
|
||||
input_ids = input_ids[:, -truncation_length:]
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel'] or shared.args.cpu:
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
|
||||
return input_ids
|
||||
elif shared.args.flexgen:
|
||||
return input_ids.numpy()
|
||||
@ -157,7 +157,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
|
||||
yield ''
|
||||
return
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
|
||||
generate_func = generate_reply_custom
|
||||
elif shared.args.flexgen:
|
||||
generate_func = generate_reply_flexgen
|
||||
@ -283,13 +283,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
||||
|
||||
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
|
||||
seed = set_manual_seed(state['seed'])
|
||||
generate_params = {'token_count': state['max_new_tokens']}
|
||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel']:
|
||||
for k in ['mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
t0 = time.time()
|
||||
reply = ''
|
||||
@ -298,13 +291,13 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
|
||||
yield ''
|
||||
|
||||
if not state['stream']:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
reply = shared.model.generate(question, state)
|
||||
if not is_chat:
|
||||
reply = apply_extensions('output', reply)
|
||||
|
||||
yield reply
|
||||
else:
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
for reply in shared.model.generate_with_streaming(question, state):
|
||||
if not is_chat:
|
||||
reply = apply_extensions('output', reply)
|
||||
|
||||
|
@ -77,7 +77,10 @@ def load_model_wrapper(selected_model, loader, autoload=False):
|
||||
else:
|
||||
yield f"Failed to load {selected_model}."
|
||||
except:
|
||||
yield traceback.format_exc()
|
||||
exc = traceback.format_exc()
|
||||
logger.error('Failed to load the model.')
|
||||
print(exc)
|
||||
yield exc
|
||||
|
||||
|
||||
def load_lora_wrapper(selected_loras):
|
||||
@ -193,7 +196,7 @@ def create_model_menus():
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp"], value=None)
|
||||
shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama", "llama.cpp"], value=None)
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@ -213,6 +216,7 @@ def create_model_menus():
|
||||
shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
|
||||
shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
|
||||
shared.gradio['autogptq_info'] = gr.Markdown('On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.')
|
||||
shared.gradio['exllama_info'] = gr.Markdown('ExLlama has to be installed manually. See the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama')
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
|
||||
|
Loading…
Reference in New Issue
Block a user