diff --git a/modules/AutoGPTQ_loader.py b/modules/AutoGPTQ_loader.py index 5b87fe56..5f2fd71b 100644 --- a/modules/AutoGPTQ_loader.py +++ b/modules/AutoGPTQ_loader.py @@ -51,4 +51,15 @@ def load_quantized(model_name): logger.info(f"The AutoGPTQ params are: {params}") model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params) + + # These lines fix the multimodal extension when used with AutoGPTQ + if not hasattr(model, 'dtype'): + model.dtype = model.model.dtype + + if not hasattr(model, 'embed_tokens'): + model.embed_tokens = model.model.model.embed_tokens + + if not hasattr(model.model, 'embed_tokens'): + model.model.embed_tokens = model.model.model.embed_tokens + return model