From 345b6dee8c65b0979812a9051864f9ae0e87d25c Mon Sep 17 00:00:00 2001
From: Ayanami Rei
Date: Mon, 13 Mar 2023 19:59:57 +0300
Subject: [PATCH] refactor quant models loader and add support of OPT

---
 .../{quantized_LLaMA.py => quant_loader.py}  | 26 +++++++------------
 1 file changed, 9 insertions(+), 17 deletions(-)
 rename modules/{quantized_LLaMA.py => quant_loader.py} (61%)

diff --git a/modules/quantized_LLaMA.py b/modules/quant_loader.py
similarity index 61%
rename from modules/quantized_LLaMA.py
rename to modules/quant_loader.py
index e9352f90..8bf505a6 100644
--- a/modules/quantized_LLaMA.py
+++ b/modules/quant_loader.py
@@ -7,28 +7,20 @@ import torch
 import modules.shared as shared
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
-from llama import load_quant
 
 
 # 4-bit LLaMA
-def load_quantized_LLaMA(model_name):
-    if shared.args.load_in_4bit:
-        bits = 4
+def load_quant(model_name, model_type):
+    if model_type == 'llama':
+        from llama import load_quant
+    elif model_type == 'opt':
+        from opt import load_quant
     else:
-        bits = shared.args.gptq_bits
+        print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
+        exit()
 
     path_to_model = Path(f'models/{model_name}')
-    pt_model = ''
-    if path_to_model.name.lower().startswith('llama-7b'):
-        pt_model = f'llama-7b-{bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-13b'):
-        pt_model = f'llama-13b-{bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-30b'):
-        pt_model = f'llama-30b-{bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-65b'):
-        pt_model = f'llama-65b-{bits}bit.pt'
-    else:
-        pt_model = f'{model_name}-{bits}bit.pt'
+    pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'
 
     # Try to find the .pt both in models/ and in the subfolder
     pt_path = None
@@ -40,7 +32,7 @@ def load_quantized_LLaMA(model_name):
         print(f"Could not find {pt_model}, exiting...")
         exit()
 
-    model = load_quant(path_to_model, str(pt_path), bits)
+    model = load_quant(path_to_model, str(pt_path), shared.args.gptq_bits)
 
     # Multiple GPUs or GPU+CPU
     if shared.args.gpu_memory:
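
Usage sketch (not confirmed by the patch itself): after this refactor, callers select the pre-quantized backend explicitly through the new model_type argument instead of the loader guessing from LLaMA model-name prefixes. A minimal example of how the renamed module might be called, assuming GPTQ-for-LLaMa is checked out under repositories/, shared.args.gptq_bits is set (e.g. 4), and a hypothetical model folder models/opt-13b with a matching opt-13b-4bit.pt checkpoint exists:

    import modules.shared as shared
    from modules.quant_loader import load_quant

    # model_type must be 'llama' or 'opt'; any other value prints an error and exits
    # 'opt-13b' is a hypothetical model name; the loader looks for
    # f'{model_name}-{shared.args.gptq_bits}bit.pt' in models/ or the model subfolder
    model = load_quant('opt-13b', 'opt')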