Extend AutoGPTQ support for any GPTQ model (#1668)

This commit is contained in:
LaaZa 2023-06-02 07:33:55 +03:00 committed by GitHub
parent b4ad060c1f
commit 9c066601f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 14 deletions

View File

@ -1,6 +1,6 @@
from pathlib import Path from pathlib import Path
from auto_gptq import AutoGPTQForCausalLM from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import modules.shared as shared import modules.shared as shared
from modules.logging_colors import logger from modules.logging_colors import logger
@ -10,9 +10,11 @@ from modules.models import get_max_memory_dict
def load_quantized(model_name): def load_quantized(model_name):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}') path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
pt_path = None pt_path = None
use_safetensors = False
# Find the model checkpoint # Find the model checkpoint
if shared.args.checkpoint:
pt_path = Path(shared.args.checkpoint)
else:
for ext in ['.safetensors', '.pt', '.bin']: for ext in ['.safetensors', '.pt', '.bin']:
found = list(path_to_model.glob(f"*{ext}")) found = list(path_to_model.glob(f"*{ext}"))
if len(found) > 0: if len(found) > 0:
@ -20,15 +22,22 @@ def load_quantized(model_name):
logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.') logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
pt_path = found[-1] pt_path = found[-1]
if ext == '.safetensors':
use_safetensors = True
break break
if pt_path is None: if pt_path is None:
logger.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.") logger.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
return return
use_safetensors = pt_path.suffix == '.safetensors'
if not (path_to_model / "quantize_config.json").exists():
quantize_config = BaseQuantizeConfig(
bits=bits if (bits := shared.args.wbits) > 0 else 4,
group_size=gs if (gs := shared.args.groupsize) > 0 else -1,
desc_act=shared.args.desc_act
)
else:
quantize_config = None
# Define the params for AutoGPTQForCausalLM.from_quantized # Define the params for AutoGPTQForCausalLM.from_quantized
params = { params = {
'model_basename': pt_path.stem, 'model_basename': pt_path.stem,
@ -36,9 +45,10 @@ def load_quantized(model_name):
'use_triton': shared.args.triton, 'use_triton': shared.args.triton,
'use_safetensors': use_safetensors, 'use_safetensors': use_safetensors,
'trust_remote_code': shared.args.trust_remote_code, 'trust_remote_code': shared.args.trust_remote_code,
'max_memory': get_max_memory_dict() 'max_memory': get_max_memory_dict(),
'quantize_config': quantize_config
} }
logger.warning(f"The AutoGPTQ params are: {params}") logger.info(f"The AutoGPTQ params are: {params}")
model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params) model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
return model return model

View File

@ -142,6 +142,7 @@ parser.add_argument('--fused_mlp', action='store_true', help='(triton) Enable fu
# AutoGPTQ # AutoGPTQ
parser.add_argument('--autogptq', action='store_true', help='Use AutoGPTQ for loading quantized models instead of the internal GPTQ loader.') parser.add_argument('--autogptq', action='store_true', help='Use AutoGPTQ for loading quantized models instead of the internal GPTQ loader.')
parser.add_argument('--triton', action='store_true', help='Use triton.') parser.add_argument('--triton', action='store_true', help='Use triton.')
parser.add_argument('--desc_act', action='store_true', help='For models that don\'t have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.')
# FlexGen # FlexGen
parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.') parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')