Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-12-23 21:18:00 +01:00
Broaden GPTQ-for-LLaMA branch support (#820)

This commit is contained in:
parent 8cd899515e
commit 39f3fec913
@@ -1,3 +1,4 @@
+import inspect
 import re
 import sys
 from pathlib import Path
@@ -19,9 +20,9 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     config = AutoConfig.from_pretrained(model)
     def noop(*args, **kwargs):
         pass
-    torch.nn.init.kaiming_uniform_ = noop
-    torch.nn.init.uniform_ = noop
-    torch.nn.init.normal_ = noop
+    torch.nn.init.kaiming_uniform_ = noop
+    torch.nn.init.uniform_ = noop
+    torch.nn.init.normal_ = noop

     torch.set_default_dtype(torch.half)
     transformers.modeling_utils._init_weights = False
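The initializer patching in the hunk above is what keeps model construction fast: the skeleton built from the config would otherwise spend time randomly initializing every weight, only for load_state_dict() to overwrite them with the quantized checkpoint. A minimal sketch of the same trick in isolation (the model id is a hypothetical placeholder; assumes torch and transformers are installed):

import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

def noop(*args, **kwargs):
    pass

# Skip random weight initialization: the real weights arrive later
# from the quantized checkpoint via load_state_dict().
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop

torch.set_default_dtype(torch.half)  # build fp16 parameters directly
transformers.modeling_utils._init_weights = False

config = AutoConfig.from_pretrained('facebook/opt-125m')  # hypothetical model id
model = AutoModelForCausalLM.from_config(config)  # fast: initializers are no-ops
torch.set_default_dtype(torch.float)  # restore the default dtype afterwards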
@@ -33,16 +34,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]
-    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
+
+    gptq_args = inspect.getfullargspec(make_quant).args
+
+    make_quant_kwargs = {
+        'module': model,
+        'names': layers,
+        'bits': wbits,
+    }
+    if 'groupsize' in gptq_args:
+        make_quant_kwargs['groupsize'] = groupsize
+    if 'faster' in gptq_args:
+        make_quant_kwargs['faster'] = faster_kernel
+    if 'kernel_switch_threshold' in gptq_args:
+        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+
+    make_quant(**make_quant_kwargs)

     del layers

     print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
-        model.load_state_dict(safe_load(checkpoint))
+        model.load_state_dict(safe_load(checkpoint), strict=False)
     else:
-        model.load_state_dict(torch.load(checkpoint))
+        model.load_state_dict(torch.load(checkpoint), strict=False)
     model.seqlen = 2048
     print('Done.')
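The inspect.getfullargspec() probe is the heart of the change: different GPTQ-for-LLaMa branches define make_quant() with different signatures, so the loader now builds its keyword arguments dynamically and passes only those the installed branch declares. A self-contained sketch of the pattern, with two hypothetical make_quant variants standing in for different branches:

import inspect

def make_quant_old(module, names, bits):  # e.g. an older branch
    return ('old', bits)

def make_quant_new(module, names, bits, groupsize=-1, faster=False):  # e.g. a newer branch
    return ('new', bits, groupsize, faster)

def call_make_quant(make_quant, module, names, wbits, groupsize=-1, faster_kernel=False):
    accepted = inspect.getfullargspec(make_quant).args
    kwargs = {'module': module, 'names': names, 'bits': wbits}
    # Forward optional arguments only when the target signature declares them.
    if 'groupsize' in accepted:
        kwargs['groupsize'] = groupsize
    if 'faster' in accepted:
        kwargs['faster'] = faster_kernel
    return make_quant(**kwargs)

print(call_make_quant(make_quant_old, None, {}, 4))             # ('old', 4)
print(call_make_quant(make_quant_new, None, {}, 4, 128, True))  # ('new', 4, 128, True)

One caveat of this approach: getfullargspec() only reports explicitly named parameters, so a branch that accepted these options through **kwargs would not be detected by the probe.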
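The strict=False change serves the same compatibility goal: by default, load_state_dict() raises a RuntimeError whenever the checkpoint's keys do not exactly match the module tree, and checkpoints produced by different GPTQ-for-LLaMa branches may not carry exactly the same set of tensors. With strict=False, the mismatches are reported instead of raised. A small sketch of the behavior (the layer and key names are arbitrary illustrations):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
state = model.state_dict()
state['extra.weight'] = torch.zeros(1)  # a key the model does not have
del state['bias']                       # a key the model expects

# strict=True (the default) would raise a RuntimeError on either mismatch;
# strict=False loads what matches and returns the rest for inspection.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['extra.weight']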