Broaden GPTQ-for-LLaMA branch support (#820)

commit 39f3fec913 (parent 8cd899515e)
Author: EyeDeck, committed by GitHub
Date: 2023-04-06 11:16:48 -04:00

@@ -1,3 +1,4 @@
+import inspect
 import re
 import sys
 from pathlib import Path
@@ -19,9 +20,9 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     config = AutoConfig.from_pretrained(model)
     def noop(*args, **kwargs):
         pass
-    torch.nn.init.kaiming_uniform_ = noop 
-    torch.nn.init.uniform_ = noop 
-    torch.nn.init.normal_ = noop 
+    torch.nn.init.kaiming_uniform_ = noop
+    torch.nn.init.uniform_ = noop
+    torch.nn.init.normal_ = noop
 
     torch.set_default_dtype(torch.half)
     transformers.modeling_utils._init_weights = False
@@ -33,16 +34,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]
-    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
+
+    gptq_args = inspect.getfullargspec(make_quant).args
+
+    make_quant_kwargs = {
+        'module': model,
+        'names': layers,
+        'bits': wbits,
+    }
+    if 'groupsize' in gptq_args:
+        make_quant_kwargs['groupsize'] = groupsize
+    if 'faster' in gptq_args:
+        make_quant_kwargs['faster'] = faster_kernel
+    if 'kernel_switch_threshold' in gptq_args:
+        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+
+    make_quant(**make_quant_kwargs)
 
     del layers
 
     print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
-        model.load_state_dict(safe_load(checkpoint))
+        model.load_state_dict(safe_load(checkpoint), strict = False)
     else:
-        model.load_state_dict(torch.load(checkpoint))
+        model.load_state_dict(torch.load(checkpoint), strict = False)
     model.seqlen = 2048
     print('Done.')
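
The pattern the new make_quant call relies on is signature introspection: query the callee's accepted parameter names at runtime and forward only the keyword arguments it declares, so one loader works across GPTQ-for-LLaMA branches whose make_quant signatures differ. Below is a minimal, self-contained sketch of that technique; the two make_quant_* functions are hypothetical stand-ins, not the actual branch signatures.

import inspect

# Hypothetical stand-ins for two branches with diverging signatures.
def make_quant_old(module, names, bits, groupsize=-1, faster=False):
    print(f'old branch: groupsize={groupsize}, faster={faster}')

def make_quant_new(module, names, bits, groupsize=-1, kernel_switch_threshold=128):
    print(f'new branch: groupsize={groupsize}, threshold={kernel_switch_threshold}')

def call_make_quant(make_quant, module, names, bits,
                    groupsize=-1, faster_kernel=False, kernel_switch_threshold=128):
    # Required parameters shared by all branches are always passed.
    kwargs = {'module': module, 'names': names, 'bits': bits}
    # Optional arguments are forwarded only if the target declares them.
    accepted = inspect.getfullargspec(make_quant).args
    if 'groupsize' in accepted:
        kwargs['groupsize'] = groupsize
    if 'faster' in accepted:
        kwargs['faster'] = faster_kernel
    if 'kernel_switch_threshold' in accepted:
        kwargs['kernel_switch_threshold'] = kernel_switch_threshold
    make_quant(**kwargs)

call_make_quant(make_quant_old, None, {}, 4)  # old branch: groupsize=-1, faster=False
call_make_quant(make_quant_new, None, {}, 4)  # new branch: groupsize=-1, threshold=128

One caveat of this approach: inspect.getfullargspec only reports named parameters, so a branch that accepts these options via **kwargs would not be detected.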
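The switch to strict = False in both load_state_dict calls serves the same goal: checkpoints produced by different branches can contain slightly different key sets, and strict loading raises on any mismatch. A small illustration of what strict=False tolerates, using a toy module rather than the actual model:

import torch

model = torch.nn.Linear(2, 2)
state = model.state_dict()
state['extra.scales'] = torch.zeros(1)  # a key another branch might serialize

# Unexpected keys are reported but not fatal.
result = model.load_state_dict(state, strict=False)
print(result.unexpected_keys)  # ['extra.scales']

# Missing keys are likewise reported; the module keeps its current tensors.
del state['weight'], state['extra.scales']
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)  # ['weight']

The trade-off is that a genuinely incomplete or mismatched checkpoint also loads without error, so such failures surface later as bad outputs rather than at load time.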