mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Initial support for LLaVA-LLaMA-2. (#3377)
This commit is contained in:
parent
9fab9a1ca6
commit
2b75d725e6
@ -146,3 +146,32 @@ class LLaVA_v0_7B_Pipeline(LLaVA_v0_Pipeline):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def llava_projector_repo() -> str:
|
def llava_projector_repo() -> str:
|
||||||
return "liuhaotian/LLaVA-7b-delta-v0"
|
return "liuhaotian/LLaVA-7b-delta-v0"
|
||||||
|
|
||||||
|
|
||||||
|
class LLaVA_LLaMA_2_13B_Pipeline(LLaVA_v0_13B_Pipeline):
|
||||||
|
def __init__(self, params: dict) -> None:
|
||||||
|
super().__init__(params)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def name() -> str:
|
||||||
|
return "llava-llama-2-13b"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def placeholder_token_id() -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def llava_projector_repo() -> str:
|
||||||
|
return "liuhaotian/llava-llama-2-13b-chat-lightning-preview"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def image_start() -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def image_end() -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def placeholder_embeddings() -> torch.Tensor:
|
||||||
|
return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*256, add_bos_token=False)[0])
|
||||||
|
@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||||
|
|
||||||
available_pipelines = ['llava-7b', 'llava-13b']
|
available_pipelines = ['llava-7b', 'llava-13b', 'llava-llama-2-13b']
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
||||||
@ -12,12 +12,19 @@ def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline
|
|||||||
if name == 'llava-13b':
|
if name == 'llava-13b':
|
||||||
from .llava import LLaVA_v0_13B_Pipeline
|
from .llava import LLaVA_v0_13B_Pipeline
|
||||||
return LLaVA_v0_13B_Pipeline(params)
|
return LLaVA_v0_13B_Pipeline(params)
|
||||||
|
if name == 'llava-llama-2-13b':
|
||||||
|
from .llava import LLaVA_LLaMA_2_13B_Pipeline
|
||||||
|
return LLaVA_LLaMA_2_13B_Pipeline(params)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
||||||
if 'llava' not in model_name.lower():
|
if 'llava' not in model_name.lower():
|
||||||
return None
|
return None
|
||||||
|
if 'llama-2' in model_name.lower():
|
||||||
|
if '13b' in model_name.lower():
|
||||||
|
from .llava import LLaVA_LLaMA_2_13B_Pipeline
|
||||||
|
return LLaVA_LLaMA_2_13B_Pipeline(params)
|
||||||
if '7b' in model_name.lower():
|
if '7b' in model_name.lower():
|
||||||
from .llava import LLaVA_v0_7B_Pipeline
|
from .llava import LLaVA_v0_7B_Pipeline
|
||||||
return LLaVA_v0_7B_Pipeline(params)
|
return LLaVA_v0_7B_Pipeline(params)
|
||||||
|
Loading…
Reference in New Issue
Block a user