diff --git a/extensions/multimodal/pipelines/llava/llava.py b/extensions/multimodal/pipelines/llava/llava.py
index 16f0e06f..eca2be50 100644
--- a/extensions/multimodal/pipelines/llava/llava.py
+++ b/extensions/multimodal/pipelines/llava/llava.py
@@ -56,7 +56,12 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):

     @staticmethod
     def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
-        return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
+        if hasattr(shared.model.model, 'embed_tokens'):
+            func = shared.model.model.embed_tokens
+        else:
+            func = shared.model.model.model.embed_tokens  # AutoGPTQ case
+
+        return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)

     @staticmethod
     def placeholder_embeddings() -> torch.Tensor:
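
Note: the change dispatches on where the token embedding layer lives. A plain Transformers model exposes it at shared.model.model.embed_tokens, whereas an AutoGPTQ-quantized model wraps the base model once more, so the layer sits one level deeper. A minimal standalone sketch of that lookup follows, assuming a HF-style model object; the helper name resolve_embed_tokens is hypothetical and not part of the patch.

    import torch

    def resolve_embed_tokens(model) -> torch.nn.Module:
        # Plain Transformers model: embedding layer is directly on model.model.
        inner = model.model
        if hasattr(inner, 'embed_tokens'):
            return inner.embed_tokens
        # AutoGPTQ-wrapped model: one extra .model nesting level.
        return inner.model.embed_tokens

    # Hypothetical usage:
    # embeds = resolve_embed_tokens(shared.model)(input_ids)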