Initial support for LLaVA-LLaMA-2. (#3377)

This commit is contained in:
Haotian Liu 2023-10-10 16:40:52 -05:00 committed by GitHub
parent 9fab9a1ca6
commit 2b75d725e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 1 deletions

View File

@ -146,3 +146,32 @@ class LLaVA_v0_7B_Pipeline(LLaVA_v0_Pipeline):
@staticmethod @staticmethod
def llava_projector_repo() -> str: def llava_projector_repo() -> str:
return "liuhaotian/LLaVA-7b-delta-v0" return "liuhaotian/LLaVA-7b-delta-v0"
class LLaVA_LLaMA_2_13B_Pipeline(LLaVA_v0_13B_Pipeline):
def __init__(self, params: dict) -> None:
super().__init__(params)
@staticmethod
def name() -> str:
return "llava-llama-2-13b"
@staticmethod
def placeholder_token_id() -> int:
return 0
@staticmethod
def llava_projector_repo() -> str:
return "liuhaotian/llava-llama-2-13b-chat-lightning-preview"
@staticmethod
def image_start() -> str:
return ""
@staticmethod
def image_end() -> str:
return ""
@staticmethod
def placeholder_embeddings() -> torch.Tensor:
return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*256, add_bos_token=False)[0])

View File

@ -2,7 +2,7 @@ from typing import Optional
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
available_pipelines = ['llava-7b', 'llava-13b'] available_pipelines = ['llava-7b', 'llava-13b', 'llava-llama-2-13b']
def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]: def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
@ -12,12 +12,19 @@ def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline
if name == 'llava-13b': if name == 'llava-13b':
from .llava import LLaVA_v0_13B_Pipeline from .llava import LLaVA_v0_13B_Pipeline
return LLaVA_v0_13B_Pipeline(params) return LLaVA_v0_13B_Pipeline(params)
if name == 'llava-llama-2-13b':
from .llava import LLaVA_LLaMA_2_13B_Pipeline
return LLaVA_LLaMA_2_13B_Pipeline(params)
return None return None
def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]: def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
if 'llava' not in model_name.lower(): if 'llava' not in model_name.lower():
return None return None
if 'llama-2' in model_name.lower():
if '13b' in model_name.lower():
from .llava import LLaVA_LLaMA_2_13B_Pipeline
return LLaVA_LLaMA_2_13B_Pipeline(params)
if '7b' in model_name.lower(): if '7b' in model_name.lower():
from .llava import LLaVA_v0_7B_Pipeline from .llava import LLaVA_v0_7B_Pipeline
return LLaVA_v0_7B_Pipeline(params) return LLaVA_v0_7B_Pipeline(params)