Merge branch 'main' into dev

This commit is contained in:
oobabooga 2023-06-06 19:43:53 -03:00
commit 878250d609

View File

@ -56,7 +56,12 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
@staticmethod @staticmethod
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor: def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype) if hasattr(shared.model.model, 'embed_tokens'):
func = shared.model.model.embed_tokens
else:
func = shared.model.model.model.embed_tokens # AutoGPTQ case
return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)
@staticmethod @staticmethod
def placeholder_embeddings() -> torch.Tensor: def placeholder_embeddings() -> torch.Tensor: