Minor rewrite

This commit is contained in:
oobabooga 2023-04-05 01:21:40 -03:00
parent f3a2e0b8a9
commit 3d6cb5ed63

View File

@ -65,13 +65,11 @@ def load_quantized(model_name):
else:
model_type = shared.args.model_type.lower()
if shared.args.pre_layer:
if model_type == 'llama':
load_quant = llama_inference_offload.load_quant
else:
print("Warning: ignoring --pre_layer because it only works for llama model type.")
load_quant = _load_quant
if shared.args.pre_layer and model_type == 'llama':
load_quant = llama_inference_offload.load_quant
elif model_type in ('llama', 'opt', 'gptj'):
if shared.args.pre_layer:
print("Warning: ignoring --pre_layer because it only works for llama model type.")
load_quant = _load_quant
else:
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")