Add ExLlama support (#2444)

oobabooga 2023-06-16 20:35:38 -03:00 committed by GitHub
parent dea43685b0
commit 9f40032d32
12 changed files with 156 additions and 47 deletions


@@ -18,7 +18,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.

## Features

* 3 interface modes: default, notebook, and chat
-* Multiple model backends: tranformers, llama.cpp, AutoGPTQ, GPTQ-for-LLaMa, RWKV, FlexGen
+* Multiple model backends: tranformers, llama.cpp, AutoGPTQ, GPTQ-for-LLaMa, ExLlama, RWKV, FlexGen
* Dropdown menu for quickly switching between different models
* LoRA: load and unload LoRAs on the fly, load multiple LoRAs at the same time, train a new LoRA
* Precise instruction templates for chat mode, including Alpaca, Vicuna, Open Assistant, Dolly, Koala, ChatGLM, MOSS, RWKV-Raven, Galactica, StableLM, WizardLM, Baize, Ziya, Chinese-Vicuna, MPT, INCITE, Wizard Mega, KoAlpaca, Vigogne, Bactrian, h2o, and OpenBuddy

@@ -215,7 +215,7 @@ Optionally, you can use the following command-line flags:

| Flag | Description |
|--------------------------------------------|-------------|
-| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: autogptq, gptq-for-llama, transformers, llamacpp, rwkv, flexgen |
+| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen |

#### Accelerate/transformers

docs/ExLlama.md (new file, +16 lines)

@@ -0,0 +1,16 @@
+# ExLlama
+
+## About
+
+ExLlama is an extremely optimized GPTQ backend for LLaMA models. It features much lower VRAM usage and much higher speeds due to not relying on unoptimized transformers code.
+
+# Installation:
+
+1) Clone the ExLlama repository into your `repositories` folder:
+
+```
+cd repositories
+git clone https://github.com/turboderp/exllama
+```
+
+2) Follow the remaining set up instructions in the official README: https://github.com/turboderp/exllama#exllama
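As a quick sanity check after cloning, the expected path layout can be verified from the webui root. This is only an illustrative sketch, not part of the commit: modules/exllama.py (added below) inserts `repositories/exllama` into `sys.path`, so the clone has to live exactly there.

```python
# Illustrative check: confirm the exllama clone is where modules/exllama.py
# expects it (repositories/exllama under the webui root).
from pathlib import Path

exllama_dir = Path("repositories") / "exllama"
assert (exllama_dir / "model.py").is_file(), \
    "Clone https://github.com/turboderp/exllama into the repositories/ folder first"
```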


@@ -38,31 +38,31 @@ class RWKVModel:
        result.cached_output_logits = None
        return result

-    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=None, token_stop=None, callback=None):
+    def generate(self, prompt, state, callback=None):
        args = PIPELINE_ARGS(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            alpha_frequency=alpha_frequency, # Frequency Penalty (as in GPT-3)
-            alpha_presence=alpha_presence, # Presence Penalty (as in GPT-3)
-            token_ban=token_ban or [0], # ban the generation of some tokens
-            token_stop=token_stop or []
+            temperature=state['temperature'],
+            top_p=state['top_p'],
+            top_k=state['top_k'],
+            alpha_frequency=0.1, # Frequency Penalty (as in GPT-3)
+            alpha_presence=0.1, # Presence Penalty (as in GPT-3)
+            token_ban=[0], # ban the generation of some tokens
+            token_stop=[]
        )

        if self.cached_context != "":
-            if context.startswith(self.cached_context):
-                context = context[len(self.cached_context):]
+            if prompt.startswith(self.cached_context):
+                prompt = prompt[len(self.cached_context):]
            else:
                self.cached_context = ""
                self.cached_model_state = None
                self.cached_output_logits = None

-        # out = self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
-        out = self.generate_from_cached_state(context, token_count=token_count, args=args, callback=callback)
+        # out = self.pipeline.generate(prompt, token_count=state['max_new_tokens'], args=args, callback=callback)
+        out = self.generate_from_cached_state(prompt, token_count=state['max_new_tokens'], args=args, callback=callback)
        return out

-    def generate_with_streaming(self, **kwargs):
-        with Iteratorize(self.generate, kwargs, callback=None) as generator:
+    def generate_with_streaming(self, *args, **kwargs):
+        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token

@@ -81,6 +81,7 @@ class RWKVModel:
        if ctx == "":
            out = self.cached_output_logits

+        token = None
        for i in range(token_count):
            # forward
            tokens = self.pipeline.encode(ctx) if i == 0 else [token]
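The functional change in this hunk (and in the llama.cpp one further down) is that custom backends now receive the whole UI state dictionary instead of a long list of keyword arguments. A minimal sketch of the keys the new RWKV signature actually reads; the values are illustrative defaults, not the project's:

```python
# Illustrative only: these are the state keys referenced by the new
# RWKV generate(); real values come from the web UI.
state = {
    'max_new_tokens': 200,
    'temperature': 0.7,
    'top_p': 0.9,
    'top_k': 40,
}

# reply = shared.model.generate(prompt, state)
```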


@@ -55,11 +55,12 @@ class Iteratorize:
    Adapted from: https://stackoverflow.com/a/9969000
    """

-    def __init__(self, func, kwargs=None, callback=None):
+    def __init__(self, func, args=None, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
+        self.args = args or []
        self.kwargs = kwargs or {}
        self.stop_now = False

@@ -70,7 +71,7 @@ class Iteratorize:
        def gentask():
            try:
-                ret = self.mfunc(callback=_callback, **self.kwargs)
+                ret = self.mfunc(callback=_callback, *args, **self.kwargs)
            except ValueError:
                pass
            except:
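With positional-argument support in place, each backend's generate_with_streaming() can simply forward (prompt, state) through Iteratorize. A rough sketch of the wrapping pattern, assuming the webui environment; the blocking function below is a stand-in, not project code:

```python
from modules.callbacks import Iteratorize


# Stand-in for a backend's blocking generate(prompt, state, callback=None).
def slow_generate(prompt, state, callback=None):
    text = ''
    for word in ['Hello', ' world']:
        text += word
        if callback:
            callback(text)  # push each partial result to the consumer
    return text


# Iteratorize calls slow_generate(callback=..., *args, **kwargs) in a thread
# and turns the callback pushes into an iterator.
with Iteratorize(slow_generate, ['Hi there', {'max_new_tokens': 8}], {}, callback=None) as generator:
    for partial in generator:
        print(partial)  # prints 'Hello', then 'Hello world'
```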

modules/exllama.py (new file, +81 lines)

@@ -0,0 +1,81 @@
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path("repositories/exllama")))
+
+from modules.logging_colors import logger
+from repositories.exllama.generator import ExLlamaGenerator
+from repositories.exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
+from repositories.exllama.tokenizer import ExLlamaTokenizer
+
+
+class ExllamaModel:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(self, path_to_model):
+        path_to_model = Path("models") / Path(path_to_model)
+        tokenizer_model_path = path_to_model / "tokenizer.model"
+        model_config_path = path_to_model / "config.json"
+
+        # Find the model checkpoint
+        model_path = None
+        for ext in ['.safetensors', '.pt', '.bin']:
+            found = list(path_to_model.glob(f"*{ext}"))
+            if len(found) > 0:
+                if len(found) > 1:
+                    logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
+
+                model_path = found[-1]
+                break
+
+        config = ExLlamaConfig(str(model_config_path))
+        config.model_path = str(model_path)
+
+        model = ExLlama(config)
+        tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
+        cache = ExLlamaCache(model)
+
+        result = self()
+        result.config = config
+        result.model = model
+        result.cache = cache
+        result.tokenizer = tokenizer
+        return result, result
+
+    def generate(self, prompt, state, callback=None):
+        generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
+        generator.settings.temperature = state['temperature']
+        generator.settings.top_p = state['top_p']
+        generator.settings.top_k = state['top_k']
+        generator.settings.typical = state['typical_p']
+        generator.settings.token_repetition_penalty_max = state['repetition_penalty']
+        if state['ban_eos_token']:
+            generator.disallow_tokens([self.tokenizer.eos_token_id])
+
+        text = generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
+        return text
+
+    def generate_with_streaming(self, prompt, state, callback=None):
+        generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
+        generator.settings.temperature = state['temperature']
+        generator.settings.top_p = state['top_p']
+        generator.settings.top_k = state['top_k']
+        generator.settings.typical = state['typical_p']
+        generator.settings.token_repetition_penalty_max = state['repetition_penalty']
+        if state['ban_eos_token']:
+            generator.disallow_tokens([self.tokenizer.eos_token_id])
+
+        generator.end_beam_search()
+        ids = generator.tokenizer.encode(prompt)
+        generator.gen_begin(ids)
+        initial_len = generator.sequence[0].shape[0]
+        for i in range(state['max_new_tokens']):
+            token = generator.gen_single_token()
+            yield (generator.tokenizer.decode(generator.sequence[0][initial_len:]))
+            if token.item() == generator.tokenizer.eos_token_id:
+                break
+
+    def encode(self, string, **kwargs):
+        return self.tokenizer.encode(string)
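A hedged usage sketch of the new wrapper, run from the webui root; the model folder name is only a placeholder, and the state keys mirror the ones read above:

```python
# Illustrative only: assumes a 4-bit GPTQ LLaMA model folder under models/
# and the exllama clone under repositories/ (see docs/ExLlama.md).
from modules.exllama import ExllamaModel

model, tokenizer = ExllamaModel.from_pretrained("llama-13b-4bit-128g")  # placeholder folder name

state = {
    'max_new_tokens': 200,
    'temperature': 0.7,
    'top_p': 0.9,
    'top_k': 40,
    'typical_p': 1.0,
    'repetition_penalty': 1.15,
    'ban_eos_token': False,
}

# Blocking generation
print(model.generate("Once upon a time", state))

# Streaming: each yield is the full completion decoded so far
for partial in model.generate_with_streaming("Once upon a time", state):
    print(partial)
```

Note that from_pretrained() returns the same object twice, so the "tokenizer" slot filled by load_model() is the wrapper itself; its encode() simply delegates to the ExLlama tokenizer.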


@@ -59,18 +59,18 @@ class LlamaCppModel:
        return self.model.tokenize(string)

-    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
-        context = context if type(context) is str else context.decode()
+    def generate(self, prompt, state, callback=None):
+        prompt = prompt if type(prompt) is str else prompt.decode()
        completion_chunks = self.model.create_completion(
-            prompt=context,
-            max_tokens=token_count,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repeat_penalty=repetition_penalty,
-            mirostat_mode=int(mirostat_mode),
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
+            prompt=prompt,
+            max_tokens=state['max_new_tokens'],
+            temperature=state['temperature'],
+            top_p=state['top_p'],
+            top_k=state['top_k'],
+            repeat_penalty=state['repetition_penalty'],
+            mirostat_mode=int(state['mirostat_mode']),
+            mirostat_tau=state['mirostat_tau'],
+            mirostat_eta=state['mirostat_eta'],
            stream=True
        )

@@ -83,8 +83,8 @@ class LlamaCppModel:
        return output

-    def generate_with_streaming(self, **kwargs):
-        with Iteratorize(self.generate, kwargs, callback=None) as generator:
+    def generate_with_streaming(self, *args, **kwargs):
+        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token


@@ -52,6 +52,9 @@ loaders_and_params = {
        'trust_remote_code',
        'transformers_info'
    ],
+    'ExLlama' : [
+        'exllama_info',
+    ]
}


@@ -48,7 +48,8 @@ def load_model(model_name, loader=None):
        'GPTQ-for-LLaMa': GPTQ_loader,
        'llama.cpp': llamacpp_loader,
        'FlexGen': flexgen_loader,
-        'RWKV': RWKV_loader
+        'RWKV': RWKV_loader,
+        'ExLlama': ExLlama_loader
    }

    if loader is None:

@@ -270,6 +271,13 @@ def AutoGPTQ_loader(model_name):
    return modules.AutoGPTQ_loader.load_quantized(model_name)


+def ExLlama_loader(model_name):
+    from modules.exllama import ExllamaModel
+
+    model, tokenizer = ExllamaModel.from_pretrained(model_name)
+    return model, tokenizer
+
+
def get_max_memory_dict():
    max_memory = {}
    if shared.args.gpu_memory:
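With the loader registered, the normal entry point can dispatch to it directly. A short sketch, assuming the webui environment and a placeholder model folder name:

```python
# Sketch: 'ExLlama' in the loader dict above now routes to ExLlama_loader().
from modules.models import load_model

model, tokenizer = load_model("llama-13b-4bit-128g", loader="ExLlama")  # placeholder name
```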


@@ -94,7 +94,7 @@ def apply_model_settings_to_state(model, state):
        loader = 'AutoGPTQ'

        # If the user is using an alternative GPTQ loader, let them keep using it
-        if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'exllama']):
+        if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama']):
            state['loader'] = loader

    for k in model_settings:


@@ -97,7 +97,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')

# Model loader
-parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: autogptq, gptq-for-llama, transformers, llamacpp, rwkv, flexgen')
+parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen')

# Accelerate/transformers
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')

@@ -212,6 +212,8 @@ def fix_loader_name(name):
        return 'AutoGPTQ'
    elif name in ['gptq-for-llama', 'gptqforllama', 'gptqllama', 'gptq for llama', 'gptq_for_llama']:
        return 'GPTQ-for-LLaMa'
+    elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
+        return 'ExLlama'

if args.loader is not None:
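The alias list only normalizes user input to the canonical loader name. A tiny sketch of the intended behavior, run from the webui root (importing modules.shared parses CLI args as a side effect):

```python
# Illustrative only: every alias maps to the canonical 'ExLlama' name.
from modules.shared import fix_loader_name

for alias in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
    assert fix_loader_name(alias) == 'ExLlama'
```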


@@ -51,7 +51,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
    if truncation_length is not None:
        input_ids = input_ids[:, -truncation_length:]

-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
        return input_ids
    elif shared.args.flexgen:
        return input_ids.numpy()

@@ -157,7 +157,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
        yield ''
        return

-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
        generate_func = generate_reply_custom
    elif shared.args.flexgen:
        generate_func = generate_reply_flexgen

@@ -283,13 +283,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
    seed = set_manual_seed(state['seed'])

-    generate_params = {'token_count': state['max_new_tokens']}
-    for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
-        generate_params[k] = state[k]
-
-    if shared.model.__class__.__name__ in ['LlamaCppModel']:
-        for k in ['mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
-            generate_params[k] = state[k]

    t0 = time.time()
    reply = ''

@@ -298,13 +291,13 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
        yield ''

    if not state['stream']:
-        reply = shared.model.generate(context=question, **generate_params)
+        reply = shared.model.generate(question, state)
        if not is_chat:
            reply = apply_extensions('output', reply)

        yield reply
    else:
-        for reply in shared.model.generate_with_streaming(context=question, **generate_params):
+        for reply in shared.model.generate_with_streaming(question, state):
            if not is_chat:
                reply = apply_extensions('output', reply)


@@ -77,7 +77,10 @@ def load_model_wrapper(selected_model, loader, autoload=False):
            else:
                yield f"Failed to load {selected_model}."
        except:
-            yield traceback.format_exc()
+            exc = traceback.format_exc()
+            logger.error('Failed to load the model.')
+            print(exc)
+            yield exc


def load_lora_wrapper(selected_loras):

@@ -193,7 +196,7 @@ def create_model_menus():
    with gr.Row():
        with gr.Column():
-            shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp"], value=None)
+            shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama", "llama.cpp"], value=None)

            with gr.Box():
                with gr.Row():
                    with gr.Column():

@@ -213,6 +216,7 @@ def create_model_menus():
                        shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
                        shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
                        shared.gradio['autogptq_info'] = gr.Markdown('On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.')
+                        shared.gradio['exllama_info'] = gr.Markdown('ExLlama has to be installed manually. See the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama')
                    with gr.Column():
                        shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)