From 39f3fec913e8ca4ad9e11bcb94ceee76f4e07d11 Mon Sep 17 00:00:00 2001
From: EyeDeck
Date: Thu, 6 Apr 2023 11:16:48 -0400
Subject: [PATCH] Broaden GPTQ-for-LLaMA branch support (#820)

---
 modules/GPTQ_loader.py | 30 +++++++++++++++++++++++-------
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index abfa33af..572947ad 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -1,3 +1,4 @@
+import inspect
 import re
 import sys
 from pathlib import Path
@@ -19,9 +20,9 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     config = AutoConfig.from_pretrained(model)
     def noop(*args, **kwargs):
         pass
-    torch.nn.init.kaiming_uniform_ = noop 
-    torch.nn.init.uniform_ = noop 
-    torch.nn.init.normal_ = noop 
+    torch.nn.init.kaiming_uniform_ = noop
+    torch.nn.init.uniform_ = noop
+    torch.nn.init.normal_ = noop
 
     torch.set_default_dtype(torch.half)
     transformers.modeling_utils._init_weights = False
@@ -33,16 +34,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]
-    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
+
+    gptq_args = inspect.getfullargspec(make_quant).args
+
+    make_quant_kwargs = {
+        'module': model,
+        'names': layers,
+        'bits': wbits,
+    }
+    if 'groupsize' in gptq_args:
+        make_quant_kwargs['groupsize'] = groupsize
+    if 'faster' in gptq_args:
+        make_quant_kwargs['faster'] = faster_kernel
+    if 'kernel_switch_threshold' in gptq_args:
+        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+
+    make_quant(**make_quant_kwargs)
 
     del layers
-    
+
     print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
-        model.load_state_dict(safe_load(checkpoint))
+        model.load_state_dict(safe_load(checkpoint), strict = False)
     else:
-        model.load_state_dict(torch.load(checkpoint))
+        model.load_state_dict(torch.load(checkpoint), strict = False)
     model.seqlen = 2048
 
     print('Done.')