mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 09:19:23 +01:00
Generalized load_quantized
This commit is contained in:
parent
46f6536fae
commit
c8207d474f
@ -4,13 +4,48 @@ from pathlib import Path
|
|||||||
|
|
||||||
import accelerate
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
|
import transformers
|
||||||
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
|
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
|
||||||
import llama
|
|
||||||
import llama_inference_offload
|
import llama_inference_offload
|
||||||
import opt
|
from quant import make_quant
|
||||||
|
from modelutils import find_layers
|
||||||
|
|
||||||
|
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head']):
|
||||||
|
config = AutoConfig.from_pretrained(model)
|
||||||
|
def noop(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
torch.nn.init.kaiming_uniform_ = noop
|
||||||
|
torch.nn.init.uniform_ = noop
|
||||||
|
torch.nn.init.normal_ = noop
|
||||||
|
|
||||||
|
torch.set_default_dtype(torch.half)
|
||||||
|
transformers.modeling_utils._init_weights = False
|
||||||
|
torch.set_default_dtype(torch.half)
|
||||||
|
model = AutoModelForCausalLM.from_config(config)
|
||||||
|
torch.set_default_dtype(torch.float)
|
||||||
|
model = model.eval()
|
||||||
|
layers = find_layers(model)
|
||||||
|
for name in exclude_layers:
|
||||||
|
if name in layers:
|
||||||
|
del layers[name]
|
||||||
|
make_quant(model, layers, wbits, groupsize, faster=faster_kernel)
|
||||||
|
|
||||||
|
del layers
|
||||||
|
|
||||||
|
print('Loading model ...')
|
||||||
|
if checkpoint.endswith('.safetensors'):
|
||||||
|
from safetensors.torch import load_file as safe_load
|
||||||
|
model.load_state_dict(safe_load(checkpoint))
|
||||||
|
else:
|
||||||
|
model.load_state_dict(torch.load(checkpoint))
|
||||||
|
model.seqlen = 2048
|
||||||
|
print('Done.')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def load_quantized(model_name):
|
def load_quantized(model_name):
|
||||||
@ -20,6 +55,8 @@ def load_quantized(model_name):
|
|||||||
model_type = 'llama'
|
model_type = 'llama'
|
||||||
elif model_name.lower().startswith(('opt', 'galactica')):
|
elif model_name.lower().startswith(('opt', 'galactica')):
|
||||||
model_type = 'opt'
|
model_type = 'opt'
|
||||||
|
elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
|
||||||
|
model_type = 'gptj'
|
||||||
else:
|
else:
|
||||||
print("Can't determine model type from model name. Please specify it manually using --model_type "
|
print("Can't determine model type from model name. Please specify it manually using --model_type "
|
||||||
"argument")
|
"argument")
|
||||||
@ -27,15 +64,12 @@ def load_quantized(model_name):
|
|||||||
else:
|
else:
|
||||||
model_type = shared.args.model_type.lower()
|
model_type = shared.args.model_type.lower()
|
||||||
|
|
||||||
if model_type == 'llama':
|
if model_type == 'llama' and shared.args.pre_layer:
|
||||||
if not shared.args.pre_layer:
|
oad_quant = llama_inference_offload.load_quant
|
||||||
load_quant = llama.load_quant
|
elif model_type in ('llama', 'opt', 'gptj'):
|
||||||
else:
|
load_quant = _load_quant
|
||||||
load_quant = llama_inference_offload.load_quant
|
|
||||||
elif model_type == 'opt':
|
|
||||||
load_quant = opt.load_quant
|
|
||||||
else:
|
else:
|
||||||
print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
|
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
# Now we are going to try to locate the quantized model file.
|
# Now we are going to try to locate the quantized model file.
|
||||||
|
Loading…
Reference in New Issue
Block a user