mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-25 13:58:56 +01:00
Initial support for LLaVA-LLaMA-2. (#3377)
This commit is contained in:
parent
9fab9a1ca6
commit
2b75d725e6
@ -146,3 +146,32 @@ class LLaVA_v0_7B_Pipeline(LLaVA_v0_Pipeline):
|
||||
@staticmethod
|
||||
def llava_projector_repo() -> str:
|
||||
return "liuhaotian/LLaVA-7b-delta-v0"
|
||||
|
||||
|
||||
class LLaVA_LLaMA_2_13B_Pipeline(LLaVA_v0_13B_Pipeline):
|
||||
def __init__(self, params: dict) -> None:
|
||||
super().__init__(params)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "llava-llama-2-13b"
|
||||
|
||||
@staticmethod
|
||||
def placeholder_token_id() -> int:
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_repo() -> str:
|
||||
return "liuhaotian/llava-llama-2-13b-chat-lightning-preview"
|
||||
|
||||
@staticmethod
|
||||
def image_start() -> str:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def image_end() -> str:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def placeholder_embeddings() -> torch.Tensor:
|
||||
return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*256, add_bos_token=False)[0])
|
||||
|
@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
|
||||
available_pipelines = ['llava-7b', 'llava-13b']
|
||||
available_pipelines = ['llava-7b', 'llava-13b', 'llava-llama-2-13b']
|
||||
|
||||
|
||||
def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
||||
@ -12,12 +12,19 @@ def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline
|
||||
if name == 'llava-13b':
|
||||
from .llava import LLaVA_v0_13B_Pipeline
|
||||
return LLaVA_v0_13B_Pipeline(params)
|
||||
if name == 'llava-llama-2-13b':
|
||||
from .llava import LLaVA_LLaMA_2_13B_Pipeline
|
||||
return LLaVA_LLaMA_2_13B_Pipeline(params)
|
||||
return None
|
||||
|
||||
|
||||
def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
||||
if 'llava' not in model_name.lower():
|
||||
return None
|
||||
if 'llama-2' in model_name.lower():
|
||||
if '13b' in model_name.lower():
|
||||
from .llava import LLaVA_LLaMA_2_13B_Pipeline
|
||||
return LLaVA_LLaMA_2_13B_Pipeline(params)
|
||||
if '7b' in model_name.lower():
|
||||
from .llava import LLaVA_v0_7B_Pipeline
|
||||
return LLaVA_v0_7B_Pipeline(params)
|
||||
|
Loading…
Reference in New Issue
Block a user