Generalized load_quantized

This commit is contained in:
Maya Eary 2023-03-28 20:38:55 +03:00
parent 46f6536fae
commit c8207d474f

View File

@ -4,13 +4,48 @@ from pathlib import Path
import accelerate import accelerate
import torch import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
import modules.shared as shared import modules.shared as shared
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa"))) sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
import llama
import llama_inference_offload import llama_inference_offload
import opt from quant import make_quant
from modelutils import find_layers
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head']):
config = AutoConfig.from_pretrained(model)
def noop(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
for name in exclude_layers:
if name in layers:
del layers[name]
make_quant(model, layers, wbits, groupsize, faster=faster_kernel)
del layers
print('Loading model ...')
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))
model.seqlen = 2048
print('Done.')
return model
def load_quantized(model_name): def load_quantized(model_name):
@ -20,6 +55,8 @@ def load_quantized(model_name):
model_type = 'llama' model_type = 'llama'
elif model_name.lower().startswith(('opt', 'galactica')): elif model_name.lower().startswith(('opt', 'galactica')):
model_type = 'opt' model_type = 'opt'
elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
model_type = 'gptj'
else: else:
print("Can't determine model type from model name. Please specify it manually using --model_type " print("Can't determine model type from model name. Please specify it manually using --model_type "
"argument") "argument")
@ -27,15 +64,12 @@ def load_quantized(model_name):
else: else:
model_type = shared.args.model_type.lower() model_type = shared.args.model_type.lower()
if model_type == 'llama': if model_type == 'llama' and shared.args.pre_layer:
if not shared.args.pre_layer: oad_quant = llama_inference_offload.load_quant
load_quant = llama.load_quant elif model_type in ('llama', 'opt', 'gptj'):
else: load_quant = _load_quant
load_quant = llama_inference_offload.load_quant
elif model_type == 'opt':
load_quant = opt.load_quant
else: else:
print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported") print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit() exit()
# Now we are going to try to locate the quantized model file. # Now we are going to try to locate the quantized model file.