mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-23 00:18:20 +01:00
Merge branch 'main' into dev
This commit is contained in:
commit
878250d609
@ -56,7 +56,12 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
|
if hasattr(shared.model.model, 'embed_tokens'):
|
||||||
|
func = shared.model.model.embed_tokens
|
||||||
|
else:
|
||||||
|
func = shared.model.model.model.embed_tokens # AutoGPTQ case
|
||||||
|
|
||||||
|
return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def placeholder_embeddings() -> torch.Tensor:
|
def placeholder_embeddings() -> torch.Tensor:
|
||||||
|
Loading…
Reference in New Issue
Block a user