From 9c066601f52ae79fbacfe91c920e7914d1c456a1 Mon Sep 17 00:00:00 2001
From: LaaZa
Date: Fri, 2 Jun 2023 07:33:55 +0300
Subject: [PATCH] Extend AutoGPTQ support for any GPTQ model (#1668)

---
 modules/AutoGPTQ_loader.py | 38 ++++++++++++++++++++++++--------------
 modules/shared.py          |  1 +
 2 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/modules/AutoGPTQ_loader.py b/modules/AutoGPTQ_loader.py
index f60dda70..5b87fe56 100644
--- a/modules/AutoGPTQ_loader.py
+++ b/modules/AutoGPTQ_loader.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from auto_gptq import AutoGPTQForCausalLM
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 
 import modules.shared as shared
 from modules.logging_colors import logger
@@ -10,25 +10,34 @@ from modules.models import get_max_memory_dict
 def load_quantized(model_name):
     path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
     pt_path = None
-    use_safetensors = False
 
     # Find the model checkpoint
-    for ext in ['.safetensors', '.pt', '.bin']:
-        found = list(path_to_model.glob(f"*{ext}"))
-        if len(found) > 0:
-            if len(found) > 1:
-                logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
+    if shared.args.checkpoint:
+        pt_path = Path(shared.args.checkpoint)
+    else:
+        for ext in ['.safetensors', '.pt', '.bin']:
+            found = list(path_to_model.glob(f"*{ext}"))
+            if len(found) > 0:
+                if len(found) > 1:
+                    logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
 
-            pt_path = found[-1]
-            if ext == '.safetensors':
-                use_safetensors = True
-
-            break
+                pt_path = found[-1]
+                break
 
     if pt_path is None:
         logger.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
         return
 
+    use_safetensors = pt_path.suffix == '.safetensors'
+    if not (path_to_model / "quantize_config.json").exists():
+        quantize_config = BaseQuantizeConfig(
+            bits=bits if (bits := shared.args.wbits) > 0 else 4,
+            group_size=gs if (gs := shared.args.groupsize) > 0 else -1,
+            desc_act=shared.args.desc_act
+        )
+    else:
+        quantize_config = None
+
     # Define the params for AutoGPTQForCausalLM.from_quantized
     params = {
         'model_basename': pt_path.stem,
@@ -36,9 +45,10 @@ def load_quantized(model_name):
         'use_triton': shared.args.triton,
         'use_safetensors': use_safetensors,
         'trust_remote_code': shared.args.trust_remote_code,
-        'max_memory': get_max_memory_dict()
+        'max_memory': get_max_memory_dict(),
+        'quantize_config': quantize_config
     }
 
-    logger.warning(f"The AutoGPTQ params are: {params}")
+    logger.info(f"The AutoGPTQ params are: {params}")
     model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
     return model
diff --git a/modules/shared.py b/modules/shared.py
index d33f8483..a7df12e1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -142,6 +142,7 @@ parser.add_argument('--fused_mlp', action='store_true', help='(triton) Enable fu
 # AutoGPTQ
 parser.add_argument('--autogptq', action='store_true', help='Use AutoGPTQ for loading quantized models instead of the internal GPTQ loader.')
 parser.add_argument('--triton', action='store_true', help='Use triton.')
+parser.add_argument('--desc_act', action='store_true', help='For models that don\'t have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.')
 
 # FlexGen
 parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
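Usage sketch (not part of the patch): with this change applied, a GPTQ checkpoint that ships without a quantize_config.json can still be loaded by describing its quantization on the command line, e.g. --autogptq --wbits 4 --groupsize 128 --desc_act. The snippet below illustrates the same fallback directly against the AutoGPTQ API, outside the webui; the model path and the bit/group-size values are placeholders chosen for illustration, not values taken from the patch.

from pathlib import Path

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

# Placeholder model folder; assumes a single .safetensors checkpoint inside it.
path_to_model = Path("models/some-gptq-model")
pt_path = next(path_to_model.glob("*.safetensors"))

# If the folder has no quantize_config.json, describe the quantization by hand;
# otherwise pass None and let AutoGPTQ read the file itself.
if not (path_to_model / "quantize_config.json").exists():
    quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
else:
    quantize_config = None

model = AutoGPTQForCausalLM.from_quantized(
    path_to_model,
    model_basename=pt_path.stem,
    device="cuda:0",
    use_safetensors=True,
    quantize_config=quantize_config,
)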