From 2b75d725e636c84add48a89d1186d6374e2a6724 Mon Sep 17 00:00:00 2001 From: Haotian Liu <6631389+haotian-liu@users.noreply.github.com> Date: Tue, 10 Oct 2023 16:40:52 -0500 Subject: [PATCH] Initial support for LLaVA-LLaMA-2. (#3377) --- .../multimodal/pipelines/llava/llava.py | 29 +++++++++++++++++++ .../multimodal/pipelines/llava/pipelines.py | 9 +++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/extensions/multimodal/pipelines/llava/llava.py b/extensions/multimodal/pipelines/llava/llava.py index 306ab227..3c75eeed 100644 --- a/extensions/multimodal/pipelines/llava/llava.py +++ b/extensions/multimodal/pipelines/llava/llava.py @@ -146,3 +146,32 @@ class LLaVA_v0_7B_Pipeline(LLaVA_v0_Pipeline): @staticmethod def llava_projector_repo() -> str: return "liuhaotian/LLaVA-7b-delta-v0" + + +class LLaVA_LLaMA_2_13B_Pipeline(LLaVA_v0_13B_Pipeline): + def __init__(self, params: dict) -> None: + super().__init__(params) + + @staticmethod + def name() -> str: + return "llava-llama-2-13b" + + @staticmethod + def placeholder_token_id() -> int: + return 0 + + @staticmethod + def llava_projector_repo() -> str: + return "liuhaotian/llava-llama-2-13b-chat-lightning-preview" + + @staticmethod + def image_start() -> str: + return "" + + @staticmethod + def image_end() -> str: + return "" + + @staticmethod + def placeholder_embeddings() -> torch.Tensor: + return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*256, add_bos_token=False)[0]) diff --git a/extensions/multimodal/pipelines/llava/pipelines.py b/extensions/multimodal/pipelines/llava/pipelines.py index 0f650c1a..c6776a5a 100644 --- a/extensions/multimodal/pipelines/llava/pipelines.py +++ b/extensions/multimodal/pipelines/llava/pipelines.py @@ -2,7 +2,7 @@ from typing import Optional from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline -available_pipelines = ['llava-7b', 'llava-13b'] +available_pipelines = ['llava-7b', 'llava-13b', 'llava-llama-2-13b'] def get_pipeline(name: str, params: 
dict) -> Optional[AbstractMultimodalPipeline]: @@ -12,12 +12,19 @@ def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline if name == 'llava-13b': from .llava import LLaVA_v0_13B_Pipeline return LLaVA_v0_13B_Pipeline(params) + if name == 'llava-llama-2-13b': + from .llava import LLaVA_LLaMA_2_13B_Pipeline + return LLaVA_LLaMA_2_13B_Pipeline(params) return None def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]: if 'llava' not in model_name.lower(): return None + if 'llama-2' in model_name.lower(): + if '13b' in model_name.lower(): + from .llava import LLaVA_LLaMA_2_13B_Pipeline + return LLaVA_LLaMA_2_13B_Pipeline(params) if '7b' in model_name.lower(): from .llava import LLaVA_v0_7B_Pipeline return LLaVA_v0_7B_Pipeline(params)