From edbc61139ff5a0ccb2c41a3d8446b231fd31ac5e Mon Sep 17 00:00:00 2001 From: Ayanami Rei Date: Mon, 13 Mar 2023 20:00:38 +0300 Subject: [PATCH] use new quant loader --- modules/models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/models.py b/modules/models.py index 7d094ed5..31696795 100644 --- a/modules/models.py +++ b/modules/models.py @@ -1,6 +1,5 @@ import json import os -import sys import time import zipfile from pathlib import Path @@ -35,6 +34,7 @@ if shared.args.deepspeed: ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir) dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration + def load_model(model_name): print(f"Loading {model_name}...") t0 = time.time() @@ -42,7 +42,7 @@ def load_model(model_name): shared.is_RWKV = model_name.lower().startswith('rwkv-') # Default settings - if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.gptq_bits > 0, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]): + if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.gptq_bits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]): if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) else: @@ -87,11 +87,11 @@ def load_model(model_name): return model, tokenizer - # 4-bit LLaMA - elif shared.args.gptq_bits > 0 or shared.args.load_in_4bit: - from modules.quantized_LLaMA import load_quantized_LLaMA + # Quantized model + elif shared.args.gptq_bits > 0: + from modules.quant_loader import load_quant - model = load_quantized_LLaMA(model_name) + model = load_quant(model_name, shared.args.gptq_model_type) # Custom else: