From 1a8151a2b655fadbb88a48de17227d6303dce339 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 17 May 2023 11:12:12 -0300 Subject: [PATCH] Add AutoGPTQ support (basic) (#2132) --- modules/AutoGPTQ_loader.py | 41 ++++++++++++++++++++++++++++++++++++++ modules/models.py | 13 ++++++++++-- modules/shared.py | 4 ++++ 3 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 modules/AutoGPTQ_loader.py diff --git a/modules/AutoGPTQ_loader.py b/modules/AutoGPTQ_loader.py new file mode 100644 index 00000000..46a8e736 --- /dev/null +++ b/modules/AutoGPTQ_loader.py @@ -0,0 +1,41 @@ +import logging +from pathlib import Path + +from auto_gptq import AutoGPTQForCausalLM + +import modules.shared as shared +from modules.models import get_max_memory_dict + + +def load_quantized(model_name): + path_to_model = Path(f'{shared.args.model_dir}/{model_name}') + pt_path = None + use_safetensors = False + + # Find the model checkpoint + found_pts = list(path_to_model.glob("*.pt")) + found_safetensors = list(path_to_model.glob("*.safetensors")) + if len(found_safetensors) > 0: + if len(found_pts) > 1: + logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.') + + use_safetensors = True + pt_path = found_safetensors[-1] + elif len(found_pts) > 0: + if len(found_pts) > 1: + logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.') + + pt_path = found_pts[-1] + + # Define the params for AutoGPTQForCausalLM.from_quantized + params = { + 'model_basename': pt_path.stem, + 'device': "cuda:0" if not shared.args.cpu else "cpu", + 'use_triton': shared.args.triton, + 'use_safetensors': use_safetensors, + 'max_memory': get_max_memory_dict() + } + + logging.warning(f"The AutoGPTQ params are: {params}") + model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params) + return model diff --git a/modules/models.py b/modules/models.py index 5696b624..41d4f479 100644 --- a/modules/models.py +++ b/modules/models.py @@ -72,7 +72,10 @@ def load_model(model_name): shared.model_type = find_model_type(model_name) if shared.args.wbits > 0: - load_func = GPTQ_loader + if shared.args.autogptq: + load_func = AutoGPTQ_loader + else: + load_func = GPTQ_loader elif shared.model_type == 'llamacpp': load_func = llamacpp_loader elif shared.model_type == 'rwkv': @@ -261,6 +264,12 @@ def GPTQ_loader(model_name): return model +def AutoGPTQ_loader(model_name): + from modules.AutoGPTQ_loader import load_quantized + + return load_quantized(model_name) + + def get_max_memory_dict(): max_memory = {} if shared.args.gpu_memory: @@ -283,7 +292,7 @@ def get_max_memory_dict(): logging.warning(f"Auto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.") max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'} - return max_memory + return max_memory if len(max_memory) > 0 else None def clear_torch_cache(): diff --git a/modules/shared.py b/modules/shared.py index 3e1241c5..7f945366 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -137,6 +137,10 @@ parser.add_argument('--quant_attn', action='store_true', help='(triton) Enable q parser.add_argument('--warmup_autotune', action='store_true', help='(triton) Enable warmup autotune.') parser.add_argument('--fused_mlp', action='store_true', help='(triton) Enable fused mlp.') +# AutoGPTQ +parser.add_argument('--autogptq', action='store_true', help='Use AutoGPTQ for loading quantized models instead of the internal GPTQ loader.') +parser.add_argument('--triton', action='store_true', help='Use triton.') + # FlexGen parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.') parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')