Make llava/minigpt-4 work with AutoGPTQ

This commit is contained in:
oobabooga 2023-06-11 17:52:23 -03:00
parent f4defde752
commit e471919e6d

View File

@ -51,4 +51,15 @@ def load_quantized(model_name):
logger.info(f"The AutoGPTQ params are: {params}")
model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
# These lines fix the multimodal extension when used with AutoGPTQ
if not hasattr(model, 'dtype'):
model.dtype = model.model.dtype
if not hasattr(model, 'embed_tokens'):
model.embed_tokens = model.model.model.embed_tokens
if not hasattr(model.model, 'embed_tokens'):
model.model.embed_tokens = model.model.model.embed_tokens
return model