From 3bbf6c601d2f166d5144eaabff4ed8fea384cf5d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 15 Dec 2023 06:46:13 -0800 Subject: [PATCH] AutoGPTQ: Add --disable_exllamav2 flag (Mixtral CPU offloading needs this) --- README.md | 1 + modules/AutoGPTQ_loader.py | 1 + modules/loaders.py | 2 ++ modules/models.py | 13 +++++++++---- modules/shared.py | 1 + modules/ui.py | 1 + modules/ui_model_menu.py | 1 + 7 files changed, 16 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d089d42c..d35ebe04 100644 --- a/README.md +++ b/README.md @@ -285,6 +285,7 @@ List of command-line flags | `--no_use_cuda_fp16` | This can make models faster on some systems. | | `--desc_act` | For models that don't have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig. | | `--disable_exllama` | Disable ExLlama kernel, which can improve inference speed on some systems. | +| `--disable_exllamav2` | Disable ExLlamav2 kernel. | #### GPTQ-for-LLaMa diff --git a/modules/AutoGPTQ_loader.py b/modules/AutoGPTQ_loader.py index f33803e8..514a6ee5 100644 --- a/modules/AutoGPTQ_loader.py +++ b/modules/AutoGPTQ_loader.py @@ -52,6 +52,7 @@ def load_quantized(model_name): 'quantize_config': quantize_config, 'use_cuda_fp16': not shared.args.no_use_cuda_fp16, 'disable_exllama': shared.args.disable_exllama, + 'disable_exllamav2': shared.args.disable_exllamav2, } logger.info(f"The AutoGPTQ params are: {params}") diff --git a/modules/loaders.py b/modules/loaders.py index d1f8343f..c7e7653e 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -25,6 +25,7 @@ loaders_and_params = OrderedDict({ 'rope_freq_base', 'compress_pos_emb', 'disable_exllama', + 'disable_exllamav2', 'transformers_info' ], 'llama.cpp': [ @@ -94,6 +95,7 @@ loaders_and_params = OrderedDict({ 'groupsize', 'desc_act', 'disable_exllama', + 'disable_exllamav2', 'gpu_memory', 'cpu_memory', 'cpu', diff --git a/modules/models.py b/modules/models.py index f77fc941..49e5f818 100644 --- a/modules/models.py +++ b/modules/models.py @@ -156,7 +156,7 @@ def huggingface_loader(model_name): LoaderClass = AutoModelForCausalLM # Load the model in simple 16-bit mode by default - if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama]): + if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama, shared.args.disable_exllamav2]): model = LoaderClass.from_pretrained(path_to_model, **params) if torch.backends.mps.is_available(): device = torch.device('mps') @@ -221,11 +221,16 @@ def huggingface_loader(model_name): if shared.args.disk: params['offload_folder'] = shared.args.disk_cache_dir - if shared.args.disable_exllama: + if shared.args.disable_exllama or shared.args.disable_exllamav2: try: - gptq_config = GPTQConfig(bits=config.quantization_config.get('bits', 4), disable_exllama=True) + gptq_config = GPTQConfig( + bits=config.quantization_config.get('bits', 4), + disable_exllama=shared.args.disable_exllama, + disable_exllamav2=shared.args.disable_exllamav2, + ) + params['quantization_config'] = gptq_config - logger.info('Loading with ExLlama kernel disabled.') + logger.info(f'Loading with disable_exllama={shared.args.disable_exllama} and disable_exllamav2={shared.args.disable_exllamav2}.') except: exc = traceback.format_exc() logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?') diff --git a/modules/shared.py b/modules/shared.py index b0888935..adebe62d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -133,6 +133,7 @@ parser.add_argument('--no_inject_fused_mlp', action='store_true', help='Triton m parser.add_argument('--no_use_cuda_fp16', action='store_true', help='This can make models faster on some systems.') parser.add_argument('--desc_act', action='store_true', help='For models that do not have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.') parser.add_argument('--disable_exllama', action='store_true', help='Disable ExLlama kernel, which can improve inference speed on some systems.') +parser.add_argument('--disable_exllamav2', action='store_true', help='Disable ExLlamav2 kernel.') # GPTQ-for-LLaMa parser.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.') diff --git a/modules/ui.py b/modules/ui.py index 45849fe3..8bfc9491 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -70,6 +70,7 @@ def list_model_elements(): 'no_inject_fused_mlp', 'no_use_cuda_fp16', 'disable_exllama', + 'disable_exllamav2', 'cfg_cache', 'no_flash_attn', 'cache_8bit', diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 7b6767dc..7242d117 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -125,6 +125,7 @@ def create_ui(): shared.gradio['logits_all'] = gr.Checkbox(label="logits_all", value=shared.args.logits_all, info='Needs to be set for perplexity evaluation to work. Otherwise, ignore it, as it makes prompt processing slower.') shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.') shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.') + shared.gradio['disable_exllamav2'] = gr.Checkbox(label="disable_exllamav2", value=shared.args.disable_exllamav2, info='Disable ExLlamav2 kernel.') shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.') shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.') shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')