From 9f40032d32165773337e6a6c60de39d3f3beb77d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 16 Jun 2023 20:35:38 -0300
Subject: [PATCH] Add ExLlama support (#2444)

---
 README.md                  |  4 +-
 docs/ExLlama.md            | 16 ++++++++
 modules/RWKV.py            | 29 +++++++-------
 modules/callbacks.py       |  5 ++-
 modules/exllama.py         | 81 ++++++++++++++++++++++++++++++++++++++
 modules/llamacpp_model.py  | 26 ++++++------
 modules/loaders.py         |  3 ++
 modules/models.py          | 10 ++++-
 modules/models_settings.py |  2 +-
 modules/shared.py          |  4 +-
 modules/text_generation.py | 15 ++-----
 server.py                  |  8 +++-
 12 files changed, 156 insertions(+), 47 deletions(-)
 create mode 100644 docs/ExLlama.md
 create mode 100644 modules/exllama.py

diff --git a/README.md b/README.md
index f998d89f..333df158 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
 ## Features
 
 * 3 interface modes: default, notebook, and chat
-* Multiple model backends: tranformers, llama.cpp, AutoGPTQ, GPTQ-for-LLaMa, RWKV, FlexGen
+* Multiple model backends: transformers, llama.cpp, AutoGPTQ, GPTQ-for-LLaMa, ExLlama, RWKV, FlexGen
 * Dropdown menu for quickly switching between different models
 * LoRA: load and unload LoRAs on the fly, load multiple LoRAs at the same time, train a new LoRA
 * Precise instruction templates for chat mode, including Alpaca, Vicuna, Open Assistant, Dolly, Koala, ChatGLM, MOSS, RWKV-Raven, Galactica, StableLM, WizardLM, Baize, Ziya, Chinese-Vicuna, MPT, INCITE, Wizard Mega, KoAlpaca, Vigogne, Bactrian, h2o, and OpenBuddy
@@ -215,7 +215,7 @@ Optionally, you can use the following command-line flags:
 
 | Flag                                       | Description |
 |--------------------------------------------|-------------|
-| `--loader LOADER`                          | Choose the model loader manually, otherwise, it will get autodetected. Valid options: autogptq, gptq-for-llama, transformers, llamacpp, rwkv, flexgen |
+| `--loader LOADER`                          | Choose the model loader manually; otherwise, it will be autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen |
 
 #### Accelerate/transformers
 
diff --git a/docs/ExLlama.md b/docs/ExLlama.md
new file mode 100644
index 00000000..a0968927
--- /dev/null
+++ b/docs/ExLlama.md
@@ -0,0 +1,16 @@
+# ExLlama
+
+## About
+
+ExLlama is an extremely optimized GPTQ backend for LLaMA models. It features much lower VRAM usage and much higher speeds because it does not rely on unoptimized transformers code.
+
+## Installation
+
+1) Clone the ExLlama repository into your `repositories` folder:
+
+```
+cd repositories
+git clone https://github.com/turboderp/exllama
+```
+
+2) Follow the remaining setup instructions in the official README: https://github.com/turboderp/exllama#exllama
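Once the repository is cloned, the new backend can be picked with `--loader exllama` on the command line or via the "Model loader" dropdown that this patch adds to the UI. As a minimal sketch (not part of the patch), the check below confirms that the clone sits where the new `modules/exllama.py` module expects it; the path handling mirrors the `sys.path` insertion that module performs, and `model.py` is assumed to exist at the clone root because the module imports `repositories.exllama.model`.

```python
# Minimal sketch, not part of the patch: verify the expected repository layout
# before launching the web UI. modules/exllama.py (added further down) resolves
# ExLlama relative to the web UI root, at repositories/exllama.
import sys
from pathlib import Path

exllama_dir = Path("repositories/exllama")
if not (exllama_dir / "model.py").is_file():
    raise SystemExit("Clone https://github.com/turboderp/exllama into repositories/ first.")

# Same sys.path manipulation that modules/exllama.py performs at import time:
sys.path.insert(0, str(exllama_dir))
```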
diff --git a/modules/RWKV.py b/modules/RWKV.py
index 08a4bd54..35d69986 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -38,31 +38,31 @@ class RWKVModel:
         result.cached_output_logits = None
         return result
 
-    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=None, token_stop=None, callback=None):
+    def generate(self, prompt, state, callback=None):
         args = PIPELINE_ARGS(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            alpha_frequency=alpha_frequency,  # Frequency Penalty (as in GPT-3)
-            alpha_presence=alpha_presence,  # Presence Penalty (as in GPT-3)
-            token_ban=token_ban or [0],  # ban the generation of some tokens
-            token_stop=token_stop or []
+            temperature=state['temperature'],
+            top_p=state['top_p'],
+            top_k=state['top_k'],
+            alpha_frequency=0.1,  # Frequency Penalty (as in GPT-3)
+            alpha_presence=0.1,  # Presence Penalty (as in GPT-3)
+            token_ban=[0],  # ban the generation of some tokens
+            token_stop=[]
         )
 
         if self.cached_context != "":
-            if context.startswith(self.cached_context):
-                context = context[len(self.cached_context):]
+            if prompt.startswith(self.cached_context):
+                prompt = prompt[len(self.cached_context):]
             else:
                 self.cached_context = ""
                 self.cached_model_state = None
                 self.cached_output_logits = None
 
-        # out = self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
-        out = self.generate_from_cached_state(context, token_count=token_count, args=args, callback=callback)
+        # out = self.pipeline.generate(prompt, token_count=state['max_new_tokens'], args=args, callback=callback)
+        out = self.generate_from_cached_state(prompt, token_count=state['max_new_tokens'], args=args, callback=callback)
         return out
 
-    def generate_with_streaming(self, **kwargs):
-        with Iteratorize(self.generate, kwargs, callback=None) as generator:
+    def generate_with_streaming(self, *args, **kwargs):
+        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
             reply = ''
             for token in generator:
                 reply += token
@@ -81,6 +81,7 @@ class RWKVModel:
         if ctx == "":
             out = self.cached_output_logits
 
+        token = None
         for i in range(token_count):
             # forward
             tokens = self.pipeline.encode(ctx) if i == 0 else [token]
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 5996ba4e..fb92e18a 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -55,11 +55,12 @@ class Iteratorize:
     Adapted from: https://stackoverflow.com/a/9969000
     """
 
-    def __init__(self, func, kwargs=None, callback=None):
+    def __init__(self, func, args=None, kwargs=None, callback=None):
         self.mfunc = func
         self.c_callback = callback
         self.q = Queue()
         self.sentinel = object()
+        self.args = args or []
         self.kwargs = kwargs or {}
         self.stop_now = False
 
@@ -70,7 +71,7 @@
 
         def gentask():
             try:
-                ret = self.mfunc(callback=_callback, **self.kwargs)
+                ret = self.mfunc(callback=_callback, *self.args, **self.kwargs)
             except ValueError:
                 pass
             except:
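The `Iteratorize` change above is what lets the RWKV and llama.cpp wrappers forward `(prompt, state)` positionally instead of as keyword arguments. Below is a minimal sketch of that pattern, assuming it runs from the web UI root so that `modules.callbacks` (and its torch/transformers dependencies) is importable; `fake_generate` is a stand-in for the real `generate` methods.

```python
# Minimal sketch, assuming the web UI environment: wrap a callback-based
# generate(prompt, state, callback) into an iterator, mirroring the
# generate_with_streaming methods in modules/RWKV.py and modules/llamacpp_model.py.
from modules.callbacks import Iteratorize

def fake_generate(prompt, state, callback=None):
    # Stand-in for RWKVModel.generate / LlamaCppModel.generate: reports each
    # new piece of text through the callback and returns the full reply.
    for piece in ["Hello", ", ", "world"]:
        if callback:
            callback(piece)
    return "Hello, world"

def fake_generate_with_streaming(*args, **kwargs):
    # Positional args are passed through to Iteratorize, as in this patch.
    with Iteratorize(fake_generate, args, kwargs, callback=None) as generator:
        reply = ''
        for token in generator:
            reply += token
            yield reply

for partial in fake_generate_with_streaming("Hi there", {'max_new_tokens': 16}):
    print(partial)
```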
diff --git a/modules/exllama.py b/modules/exllama.py
new file mode 100644
index 00000000..11deb9b0
--- /dev/null
+++ b/modules/exllama.py
@@ -0,0 +1,81 @@
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path("repositories/exllama")))
+
+from modules.logging_colors import logger
+from repositories.exllama.generator import ExLlamaGenerator
+from repositories.exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
+from repositories.exllama.tokenizer import ExLlamaTokenizer
+
+
+class ExllamaModel:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(self, path_to_model):
+
+        path_to_model = Path("models") / Path(path_to_model)
+        tokenizer_model_path = path_to_model / "tokenizer.model"
+        model_config_path = path_to_model / "config.json"
+
+        # Find the model checkpoint
+        model_path = None
+        for ext in ['.safetensors', '.pt', '.bin']:
+            found = list(path_to_model.glob(f"*{ext}"))
+            if len(found) > 0:
+                if len(found) > 1:
+                    logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
+
+                model_path = found[-1]
+                break
+
+        config = ExLlamaConfig(str(model_config_path))
+        config.model_path = str(model_path)
+        model = ExLlama(config)
+        tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
+        cache = ExLlamaCache(model)
+
+        result = self()
+        result.config = config
+        result.model = model
+        result.cache = cache
+        result.tokenizer = tokenizer
+        return result, result
+
+    def generate(self, prompt, state, callback=None):
+        generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
+        generator.settings.temperature = state['temperature']
+        generator.settings.top_p = state['top_p']
+        generator.settings.top_k = state['top_k']
+        generator.settings.typical = state['typical_p']
+        generator.settings.token_repetition_penalty_max = state['repetition_penalty']
+        if state['ban_eos_token']:
+            generator.disallow_tokens([self.tokenizer.eos_token_id])
+
+        text = generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
+        return text
+
+    def generate_with_streaming(self, prompt, state, callback=None):
+        generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
+        generator.settings.temperature = state['temperature']
+        generator.settings.top_p = state['top_p']
+        generator.settings.top_k = state['top_k']
+        generator.settings.typical = state['typical_p']
+        generator.settings.token_repetition_penalty_max = state['repetition_penalty']
+        if state['ban_eos_token']:
+            generator.disallow_tokens([self.tokenizer.eos_token_id])
+
+        generator.end_beam_search()
+        ids = generator.tokenizer.encode(prompt)
+        generator.gen_begin(ids)
+        initial_len = generator.sequence[0].shape[0]
+        for i in range(state['max_new_tokens']):
+            token = generator.gen_single_token()
+            yield (generator.tokenizer.decode(generator.sequence[0][initial_len:]))
+            if token.item() == generator.tokenizer.eos_token_id:
+                break
+
+    def encode(self, string, **kwargs):
+        return self.tokenizer.encode(string)
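For reference, this is roughly how the new wrapper can be driven on its own, outside the web UI. It is only a sketch: the model folder name is a placeholder, and the `state` dictionary simply contains the keys that `ExllamaModel` reads above.

```python
# Minimal sketch, not part of the patch. Assumes the web UI root as working
# directory, ExLlama cloned under repositories/exllama, and a GPTQ LLaMA model
# (placeholder name below) under models/ with config.json, tokenizer.model and
# a .safetensors/.pt/.bin checkpoint.
from modules.exllama import ExllamaModel

state = {
    'max_new_tokens': 64,
    'temperature': 0.7,
    'top_p': 0.9,
    'top_k': 40,
    'typical_p': 1.0,
    'repetition_penalty': 1.15,
    'ban_eos_token': False,
}

# from_pretrained returns the same object twice so that load_model() can
# unpack it as (model, tokenizer).
model, tokenizer = ExllamaModel.from_pretrained("MyLLaMA-GPTQ")

print(model.generate("Once upon a time,", state))

# Streaming variant: yields the decoded text generated so far on every token.
text = ""
for text in model.generate_with_streaming("Once upon a time,", state):
    pass
print(text)
```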
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index 4f2de155..578d2b4b 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -59,18 +59,18 @@ class LlamaCppModel:
 
         return self.model.tokenize(string)
 
-    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
-        context = context if type(context) is str else context.decode()
+    def generate(self, prompt, state, callback=None):
+        prompt = prompt if type(prompt) is str else prompt.decode()
         completion_chunks = self.model.create_completion(
-            prompt=context,
-            max_tokens=token_count,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repeat_penalty=repetition_penalty,
-            mirostat_mode=int(mirostat_mode),
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
+            prompt=prompt,
+            max_tokens=state['max_new_tokens'],
+            temperature=state['temperature'],
+            top_p=state['top_p'],
+            top_k=state['top_k'],
+            repeat_penalty=state['repetition_penalty'],
+            mirostat_mode=int(state['mirostat_mode']),
+            mirostat_tau=state['mirostat_tau'],
+            mirostat_eta=state['mirostat_eta'],
             stream=True
         )
 
@@ -83,8 +83,8 @@ class LlamaCppModel:
 
         return output
 
-    def generate_with_streaming(self, **kwargs):
-        with Iteratorize(self.generate, kwargs, callback=None) as generator:
+    def generate_with_streaming(self, *args, **kwargs):
+        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
             reply = ''
             for token in generator:
                 reply += token
diff --git a/modules/loaders.py b/modules/loaders.py
index 43fff5c7..87fac259 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -52,6 +52,9 @@ loaders_and_params = {
         'trust_remote_code',
         'transformers_info'
     ],
+    'ExLlama': [
+        'exllama_info',
+    ]
 }
 
 
diff --git a/modules/models.py b/modules/models.py
index 027d2bfe..4a4ea718 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -48,7 +48,8 @@ def load_model(model_name, loader=None):
         'GPTQ-for-LLaMa': GPTQ_loader,
         'llama.cpp': llamacpp_loader,
         'FlexGen': flexgen_loader,
-        'RWKV': RWKV_loader
+        'RWKV': RWKV_loader,
+        'ExLlama': ExLlama_loader
     }
 
     if loader is None:
@@ -270,6 +271,13 @@ def AutoGPTQ_loader(model_name):
     return modules.AutoGPTQ_loader.load_quantized(model_name)
 
 
+def ExLlama_loader(model_name):
+    from modules.exllama import ExllamaModel
+
+    model, tokenizer = ExllamaModel.from_pretrained(model_name)
+    return model, tokenizer
+
+
 def get_max_memory_dict():
     max_memory = {}
     if shared.args.gpu_memory:
diff --git a/modules/models_settings.py b/modules/models_settings.py
index 2132d71d..2b24bb14 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -94,7 +94,7 @@ def apply_model_settings_to_state(model, state):
             loader = 'AutoGPTQ'
 
         # If the user is using an alternative GPTQ loader, let them keep using it
-        if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'exllama']):
+        if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama']):
             state['loader'] = loader
 
     for k in model_settings:
diff --git a/modules/shared.py b/modules/shared.py
index c041f354..c7dba9cb 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -97,7 +97,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
 
 # Model loader
-parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: autogptq, gptq-for-llama, transformers, llamacpp, rwkv, flexgen')
+parser.add_argument('--loader', type=str, help='Choose the model loader manually; otherwise, it will be autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen')
 
 # Accelerate/transformers
 parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
@@ -212,6 +212,8 @@ def fix_loader_name(name):
         return 'AutoGPTQ'
     elif name in ['gptq-for-llama', 'gptqforllama', 'gptqllama', 'gptq for llama', 'gptq_for_llama']:
         return 'GPTQ-for-LLaMa'
+    elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
+        return 'ExLlama'
 
 
 if args.loader is not None:
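The shared.py hunk above is the piece that ties the CLI to the new loader: `fix_loader_name()` normalizes the accepted spellings to 'ExLlama', which is then looked up in `load_model()`'s loader table. A small sketch follows, assuming it is run from the web UI root with no extra command-line arguments (importing `modules.shared` parses `sys.argv`).

```python
# Minimal sketch: the alias handling added to fix_loader_name() in this patch.
from modules.shared import fix_loader_name

for alias in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
    assert fix_loader_name(alias) == 'ExLlama'
print("all aliases resolve to 'ExLlama'")
```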
diff --git a/modules/text_generation.py b/modules/text_generation.py
index bba2e524..7535d141 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -51,7 +51,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
         return input_ids
     elif shared.args.flexgen:
         return input_ids.numpy()
@@ -157,7 +157,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
         yield ''
         return
 
-    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
+    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
         generate_func = generate_reply_custom
     elif shared.args.flexgen:
         generate_func = generate_reply_flexgen
@@ -283,13 +283,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
 
 def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
     seed = set_manual_seed(state['seed'])
-    generate_params = {'token_count': state['max_new_tokens']}
-    for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
-        generate_params[k] = state[k]
-
-    if shared.model.__class__.__name__ in ['LlamaCppModel']:
-        for k in ['mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
-            generate_params[k] = state[k]
 
     t0 = time.time()
     reply = ''
@@ -298,13 +291,13 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
             yield ''
 
         if not state['stream']:
-            reply = shared.model.generate(context=question, **generate_params)
+            reply = shared.model.generate(question, state)
             if not is_chat:
                 reply = apply_extensions('output', reply)
 
             yield reply
         else:
-            for reply in shared.model.generate_with_streaming(context=question, **generate_params):
+            for reply in shared.model.generate_with_streaming(question, state):
                 if not is_chat:
                     reply = apply_extensions('output', reply)
 
diff --git a/server.py b/server.py
index 6433330b..ed6e1f4f 100644
--- a/server.py
+++ b/server.py
@@ -77,7 +77,10 @@ def load_model_wrapper(selected_model, loader, autoload=False):
         else:
             yield f"Failed to load {selected_model}."
     except:
-        yield traceback.format_exc()
+        exc = traceback.format_exc()
+        logger.error('Failed to load the model.')
+        print(exc)
+        yield exc
 
 
 def load_lora_wrapper(selected_loras):
@@ -193,7 +196,7 @@ def create_model_menus():
 
     with gr.Row():
         with gr.Column():
-            shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp"], value=None)
+            shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama", "llama.cpp"], value=None)
             with gr.Box():
                 with gr.Row():
                     with gr.Column():
@@ -213,6 +216,7 @@ def create_model_menus():
                         shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
                         shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
                         shared.gradio['autogptq_info'] = gr.Markdown('On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.')
+                        shared.gradio['exllama_info'] = gr.Markdown('ExLlama has to be installed manually. See the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md)')
 
                     with gr.Column():
                         shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)