Add some checks to AutoGPTQ loader

This commit is contained in:
oobabooga 2023-06-14 18:44:43 -03:00
parent 134430bbe2
commit 4d508cbe58

View File

@ -53,13 +53,16 @@ def load_quantized(model_name):
model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
# These lines fix the multimodal extension when used with AutoGPTQ
if not hasattr(model, 'dtype'):
model.dtype = model.model.dtype
if hasattr(model, 'model'):
if not hasattr(model, 'dtype'):
if hasattr(model.model, 'dtype'):
model.dtype = model.model.dtype
if not hasattr(model, 'embed_tokens'):
model.embed_tokens = model.model.model.embed_tokens
if hasattr(model.model, 'model') and hasattr(model.model.model, 'embed_tokens'):
if not hasattr(model, 'embed_tokens'):
model.embed_tokens = model.model.model.embed_tokens
if not hasattr(model.model, 'embed_tokens'):
model.model.embed_tokens = model.model.model.embed_tokens
if not hasattr(model.model, 'embed_tokens'):
model.model.embed_tokens = model.model.model.embed_tokens
return model