Add warnings for when AutoGPTQ, TensorRT-LLM, or HQQ are missing

This commit is contained in:
oobabooga 2024-09-28 20:30:24 -07:00
parent 65e5864084
commit b92d7fd43e

View File

@ -70,11 +70,11 @@ def load_model(model_name, loader=None):
shared.model_name = model_name shared.model_name = model_name
load_func_map = { load_func_map = {
'Transformers': huggingface_loader, 'Transformers': huggingface_loader,
'AutoGPTQ': AutoGPTQ_loader,
'llama.cpp': llamacpp_loader, 'llama.cpp': llamacpp_loader,
'llamacpp_HF': llamacpp_HF_loader, 'llamacpp_HF': llamacpp_HF_loader,
'ExLlamav2': ExLlamav2_loader, 'ExLlamav2': ExLlamav2_loader,
'ExLlamav2_HF': ExLlamav2_HF_loader, 'ExLlamav2_HF': ExLlamav2_HF_loader,
'AutoGPTQ': AutoGPTQ_loader,
'HQQ': HQQ_loader, 'HQQ': HQQ_loader,
'TensorRT-LLM': TensorRT_LLM_loader, 'TensorRT-LLM': TensorRT_LLM_loader,
} }
@ -302,12 +302,6 @@ def llamacpp_HF_loader(model_name):
return model return model
def AutoGPTQ_loader(model_name):
import modules.AutoGPTQ_loader
return modules.AutoGPTQ_loader.load_quantized(model_name)
def ExLlamav2_loader(model_name): def ExLlamav2_loader(model_name):
from modules.exllamav2 import Exllamav2Model from modules.exllamav2 import Exllamav2Model
@ -321,9 +315,21 @@ def ExLlamav2_HF_loader(model_name):
return Exllamav2HF.from_pretrained(model_name) return Exllamav2HF.from_pretrained(model_name)
def AutoGPTQ_loader(model_name):
try:
import modules.AutoGPTQ_loader
except ModuleNotFoundError:
raise ModuleNotFoundError("Failed to import 'autogptq'. Please install it manually following the instructions in the AutoGPTQ GitHub repository.")
return modules.AutoGPTQ_loader.load_quantized(model_name)
def HQQ_loader(model_name): def HQQ_loader(model_name):
from hqq.core.quantize import HQQBackend, HQQLinear try:
from hqq.models.hf.base import AutoHQQHFModel from hqq.core.quantize import HQQBackend, HQQLinear
from hqq.models.hf.base import AutoHQQHFModel
except ModuleNotFoundError:
raise ModuleNotFoundError("Failed to import 'hqq'. Please install it manually following the instructions in the HQQ GitHub repository.")
logger.info(f"Loading HQQ model with backend: \"{shared.args.hqq_backend}\"") logger.info(f"Loading HQQ model with backend: \"{shared.args.hqq_backend}\"")
@ -334,7 +340,10 @@ def HQQ_loader(model_name):
def TensorRT_LLM_loader(model_name): def TensorRT_LLM_loader(model_name):
from modules.tensorrt_llm import TensorRTLLMModel try:
from modules.tensorrt_llm import TensorRTLLMModel
except ModuleNotFoundError:
raise ModuleNotFoundError("Failed to import 'tensorrt_llm'. Please install it manually following the instructions in the TensorRT-LLM GitHub repository.")
model = TensorRTLLMModel.from_pretrained(model_name) model = TensorRTLLMModel.from_pretrained(model_name)
return model return model