mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-12-24 13:28:59 +01:00
Broaden GPTQ-for-LLaMA branch support (#820)
parent 8cd899515e
commit 39f3fec913
modules/GPTQ_loader.py

@@ -1,3 +1,4 @@
+import inspect
 import re
 import sys
 from pathlib import Path
@@ -19,9 +20,9 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     config = AutoConfig.from_pretrained(model)
     def noop(*args, **kwargs):
         pass
     torch.nn.init.kaiming_uniform_ = noop
     torch.nn.init.uniform_ = noop
     torch.nn.init.normal_ = noop

     torch.set_default_dtype(torch.half)
     transformers.modeling_utils._init_weights = False
@@ -33,16 +34,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]
-    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
+    gptq_args = inspect.getfullargspec(make_quant).args
+
+    make_quant_kwargs = {
+        'module': model,
+        'names': layers,
+        'bits': wbits,
+    }
+    if 'groupsize' in gptq_args:
+        make_quant_kwargs['groupsize'] = groupsize
+    if 'faster' in gptq_args:
+        make_quant_kwargs['faster'] = faster_kernel
+    if 'kernel_switch_threshold' in gptq_args:
+        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+
+    make_quant(**make_quant_kwargs)
+
     del layers

     print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
-        model.load_state_dict(safe_load(checkpoint))
+        model.load_state_dict(safe_load(checkpoint), strict = False)
     else:
-        model.load_state_dict(torch.load(checkpoint))
+        model.load_state_dict(torch.load(checkpoint), strict = False)
     model.seqlen = 2048
     print('Done.')

     return model
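
The core of the change is the argument-introspection shim: instead of assuming every GPTQ-for-LLaMa branch exposes the same make_quant() signature, the loader asks the imported function which parameters it actually declares and forwards only those. Below is a minimal, self-contained sketch of that pattern; the two make_quant_* stubs are hypothetical stand-ins for an older and a newer branch, not code from the repository.

import inspect

# Hypothetical stand-ins for two GPTQ-for-LLaMa branches: an older
# make_quant() that only understands groupsize, and a newer one that
# also accepts faster and kernel_switch_threshold.
def make_quant_old(module, names, bits, groupsize=-1):
    print(f'old branch: bits={bits}, groupsize={groupsize}')

def make_quant_new(module, names, bits, groupsize=-1, faster=False,
                   kernel_switch_threshold=128):
    print(f'new branch: bits={bits}, groupsize={groupsize}, '
          f'faster={faster}, threshold={kernel_switch_threshold}')

def call_make_quant(make_quant, module, names, bits, groupsize=-1,
                    faster_kernel=False, kernel_switch_threshold=128):
    # Ask the imported function which parameters it declares...
    gptq_args = inspect.getfullargspec(make_quant).args

    # ...and forward only the optional arguments it can accept.
    kwargs = {'module': module, 'names': names, 'bits': bits}
    if 'groupsize' in gptq_args:
        kwargs['groupsize'] = groupsize
    if 'faster' in gptq_args:
        kwargs['faster'] = faster_kernel
    if 'kernel_switch_threshold' in gptq_args:
        kwargs['kernel_switch_threshold'] = kernel_switch_threshold
    make_quant(**kwargs)

call_make_quant(make_quant_old, None, {}, 4, groupsize=128)       # extras silently dropped
call_make_quant(make_quant_new, None, {}, 4, faster_kernel=True)  # everything forwarded

Keeping a single call site working across branches comes at the cost of silently dropping options an older kernel cannot honor.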
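
The other change loads checkpoints with strict = False, so a state dict whose keys do not line up exactly with the freshly built model (as can happen when quantized layers differ between GPTQ-for-LLaMa branches) no longer aborts the load. A minimal sketch of that PyTorch behavior, using a toy model rather than anything from the repository:

import torch
import torch.nn as nn

# Toy model whose state dict will not match the checkpoint exactly.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

# Checkpoint missing the second layer and carrying an extra tensor,
# loosely imitating a checkpoint written by a different branch.
checkpoint = {
    '0.weight': torch.zeros(4, 4),
    '0.bias': torch.zeros(4),
    'extra.scales': torch.ones(4),
}

# strict=True (the default) raises a RuntimeError on this mismatch;
# strict=False loads the matching keys and reports the rest instead.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['1.weight', '1.bias']
print(result.unexpected_keys)  # ['extra.scales']

Non-strict loading also hides genuine mismatches, so the returned missing_keys and unexpected_keys are worth logging.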