Support LLaVA v1.5 (#4305)

Haotian Liu 2023-10-20 00:28:14 -05:00 committed by GitHub
parent bb71272903
commit 32984ea2f0
6 changed files with 111 additions and 18 deletions

View File

@@ -4,6 +4,8 @@
Adds support for multimodality (text+images) to text-generation-webui.

Note: multimodal currently only works for transformers, AutoGPTQ, and GPTQ-for-LLaMa loaders. ExLlama (v1 and v2) and llama.cpp support are planned.

https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4

## Usage
@@ -11,13 +13,15 @@ https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b
To run this extension, download an LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples:
```
python server.py --model liuhaotian_llava-v1.5-13b --multimodal-pipeline llava-v1.5-13b --load-in-4bit
python server.py --model TheBloke_llava-v1.5-13B-GPTQ_gptq-4bit-32g-actorder_True --multimodal-pipeline llava-v1.5-13b --disable_exllama --loader autogptq
python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b
python server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b
python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b
python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b
```
There is built-in support for LLaVA-v0-13B, LLaVA-v0-7b, and LLaVA-v1.5-13B. To install `minigpt4`:
- clone https://github.com/Wojtab/minigpt-4-pipeline into `extensions/multimodal/pipelines`
- install the requirements.txt
@@ -31,6 +35,7 @@ To send an image, just upload it to the extension field below chat, and send a p
Additionally, there is an *Embed all images, not only the last one* checkbox. It modifies the image embeddings: by default (if it's unchecked), all but the most recent images have their embeddings left empty, so they are not fed to the network. Some multimodal networks appear to consider the features in all images at once, as if they were a single image; due to this behavior, the extension skips previous images by default. However, this can lead to sub-par generation on other pipelines. If you want to include all images, just tick this checkbox.
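As a rough illustration of that default, here is a minimal sketch (hypothetical helper name, with zero tensors standing in for "empty" embeddings) of keeping only the most recent image unless the checkbox is ticked:

```python
import torch

def select_image_embeddings(image_embeds: list, embed_all_images: bool) -> list:
    # Sketch only: unless "Embed all images" is ticked, earlier images are
    # replaced with empty (here: zero) embeddings, so only the last image is
    # actually fed to the network.
    if embed_all_images or len(image_embeds) <= 1:
        return image_embeds
    return [torch.zeros_like(e) for e in image_embeds[:-1]] + [image_embeds[-1]]
```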
## Compatibility
As of now, the following multimodal pipelines are supported:
|Pipeline|`--multimodal-pipeline`|Default LLM|LLM info (for the linked model)|Pipeline repository|
|-|-|-|-|-|

View File

@@ -13,6 +13,20 @@ from modules.logging_colors import logger
from modules.text_generation import encode


def expand2square(pil_img: Image.Image, background_color: Tuple[int]) -> Image.Image:
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
    CLIP_REPO = "openai/clip-vit-large-patch14"
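A quick illustration of the `expand2square` helper above (a standalone sketch; the background color uses the usual CLIP `image_mean` values, which are an assumption here rather than being read from the processor):

```python
from PIL import Image

# Pad a 2:1 landscape image to a square, centered vertically, filling the
# borders with the (assumed) CLIP mean color, as embed_images() does for v1.5.
background = tuple(int(x * 255) for x in (0.48145466, 0.4578275, 0.40821073))
img = Image.new("RGB", (200, 100), (255, 0, 0))
padded = expand2square(img, background)  # expand2square from the diff above
print(padded.size)  # (200, 200)
```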
@@ -27,21 +41,33 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
    def _load_models(self):
        start_ts = time.time()

        logger.info(f"LLaVA - Loading CLIP from {self.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
        image_processor = CLIPImageProcessor.from_pretrained(self.CLIP_REPO, torch_dtype=self.clip_dtype)
        vision_tower = CLIPVisionModel.from_pretrained(self.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)

        logger.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
        projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
        mm_projector = self.build_mm_projector()
        projector_data = torch.load(projector_path)
        projector_data = {k[19:]: v for k, v in projector_data.items() if k.startswith('model.mm_projector.')}
        mm_projector.load_state_dict(projector_data)
        mm_projector = mm_projector.to(self.projector_device)
        logger.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
        return image_processor, vision_tower, mm_projector

    def build_mm_projector(self) -> torch.nn.Module:
        projector_shape = self.llava_projector_shape()
        if len(projector_shape) == 2:
            return torch.nn.Linear(*projector_shape)
        else:
            modules = []
            modules.append(torch.nn.Linear(projector_shape[0], projector_shape[1]))
            for i in range(2, len(projector_shape)):
                modules.append(torch.nn.GELU())
                modules.append(torch.nn.Linear(projector_shape[i-1], projector_shape[i]))

            return torch.nn.Sequential(*modules)

    @staticmethod
    def image_start() -> str:
        return "<im_start>"
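For context on the projector change: the llava-v1.5 checkpoint stores the projector as a small MLP, so after stripping the `model.mm_projector.` prefix (19 characters) the remaining keys should line up with a `torch.nn.Sequential`'s own parameter names. A sketch under that assumption:

```python
import torch

# The v1.5 projector shape (1024, 5120, 5120) builds Linear -> GELU -> Linear,
# so load_state_dict() expects keys like '0.weight', '0.bias', '2.weight', '2.bias'.
projector = torch.nn.Sequential(
    torch.nn.Linear(1024, 5120),
    torch.nn.GELU(),
    torch.nn.Linear(5120, 5120),
)
print(list(projector.state_dict().keys()))
# ['0.weight', '0.bias', '2.weight', '2.bias']
```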
@@ -175,3 +201,50 @@ class LLaVA_LLaMA_2_13B_Pipeline(LLaVA_v0_13B_Pipeline):
    @staticmethod
    def placeholder_embeddings() -> torch.Tensor:
        return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*256, add_bos_token=False)[0])


class LLaVA_v1_5_13B_Pipeline(LLaVA_v0_13B_Pipeline):
    CLIP_REPO = "openai/clip-vit-large-patch14-336"

    def __init__(self, params: dict) -> None:
        super().__init__(params)

    @staticmethod
    def name() -> str:
        return "llava-v1.5-13b"

    @staticmethod
    def llava_projector_shape() -> Tuple[int, int]:
        return (1024, 5120, 5120)

    @staticmethod
    def placeholder_token_id() -> int:
        return 0

    @staticmethod
    def llava_projector_repo() -> str:
        return "liuhaotian/llava-v1.5-13b"

    @staticmethod
    def image_start() -> str:
        return ""

    @staticmethod
    def image_end() -> str:
        return ""

    @staticmethod
    def num_image_embeds() -> int:
        return 576

    def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
        # pad it to square first
        images = [
            expand2square(image, tuple(int(x*255) for x in self.image_processor.image_mean))
            for image in images
        ]
        return super().embed_images(images)

    @staticmethod
    def placeholder_embeddings() -> torch.Tensor:
        return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*576, add_bos_token=False)[0])
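The 576 figure follows from the 336 px CLIP tower: ViT-L/14 splits a 336x336 input into a 24x24 grid of patches, one embedding per patch.

```python
# 336 px input with 14 px patches -> 24 x 24 = 576 image embeddings per picture.
image_size, patch_size = 336, 14
num_image_embeds = (image_size // patch_size) ** 2
print(num_image_embeds)  # 576
```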

View File

@@ -2,7 +2,7 @@ from typing import Optional

from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline

available_pipelines = ['llava-7b', 'llava-13b', 'llava-llama-2-13b', 'llava-v1.5-13b']


def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
@@ -15,6 +15,9 @@ def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline
    if name == 'llava-llama-2-13b':
        from .llava import LLaVA_LLaMA_2_13B_Pipeline
        return LLaVA_LLaMA_2_13B_Pipeline(params)
    if name == 'llava-v1.5-13b':
        from .llava import LLaVA_v1_5_13B_Pipeline
        return LLaVA_v1_5_13B_Pipeline(params)
    return None
@@ -25,10 +28,15 @@ def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[Abst
        if '13b' in model_name.lower():
            from .llava import LLaVA_LLaMA_2_13B_Pipeline
            return LLaVA_LLaMA_2_13B_Pipeline(params)
    elif 'llava-v1.5' in model_name.lower():
        if '13b' in model_name.lower():
            from .llava import LLaVA_v1_5_13B_Pipeline
            return LLaVA_v1_5_13B_Pipeline(params)
    else:
        if '7b' in model_name.lower():
            from .llava import LLaVA_v0_7B_Pipeline
            return LLaVA_v0_7B_Pipeline(params)
        if '13b' in model_name.lower():
            from .llava import LLaVA_v0_13B_Pipeline
            return LLaVA_v0_13B_Pipeline(params)
    return None
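To show which folder names end up on which branch, here is a standalone sketch of the routing above (the enclosing 'llava'/'llama-2' checks are assumed from the rest of the function, which this hunk only partially shows):

```python
def which_pipeline(model_name: str) -> str:
    # Mirrors the branch structure of get_pipeline_from_model_name, names only.
    name = model_name.lower()
    if 'llava' not in name:
        return 'none'
    if 'llama-2' in name:
        return 'llava-llama-2-13b' if '13b' in name else 'none'
    elif 'llava-v1.5' in name:
        return 'llava-v1.5-13b' if '13b' in name else 'none'
    else:
        if '7b' in name:
            return 'llava-7b'
        return 'llava-13b' if '13b' in name else 'none'

print(which_pipeline('liuhaotian_llava-v1.5-13b'))     # llava-v1.5-13b
print(which_pipeline('wojtab_llava-7b-v0-4bit-128g'))  # llava-7b
```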

View File

@@ -46,23 +46,24 @@ def chat_input_modifier(text, visible_text, state):

def add_chat_picture(picture, text, visible_text):
    # resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
    # Adjusted to 336 for the values here, due to the increased resolution in llava-v1.5
    max_hw, min_hw = max(picture.size), min(picture.size)
    aspect_ratio = max_hw / min_hw
    shortest_edge = int(max(336 / aspect_ratio, 336))
    longest_edge = int(shortest_edge * aspect_ratio)
    w = shortest_edge if picture.width < picture.height else longest_edge
    h = shortest_edge if picture.width >= picture.height else longest_edge
    picture = picture.resize((w, h))

    buffer = BytesIO()
    picture.save(buffer, format="PNG")
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    image = f'<img src="data:image/jpeg;base64,{img_str}">'

    if '<image>' in text:
        text = text.replace('<image>', image)
    else:
        text = image + '\n' + text

    if visible_text == '' or visible_text is None:
        visible_text = text
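Worked numbers for the new bounds (example size only): a 1024x768 upload now comes out with its shortest edge at exactly 336.

```python
# Example: 1024x768 -> aspect ratio ~1.33; shortest edge clamps to 336, longest to 448.
width, height = 1024, 768
max_hw, min_hw = max(width, height), min(width, height)
aspect_ratio = max_hw / min_hw
shortest_edge = int(max(336 / aspect_ratio, 336))  # 336
longest_edge = int(shortest_edge * aspect_ratio)   # 448
w = shortest_edge if width < height else longest_edge
h = shortest_edge if width >= height else longest_edge
print(w, h)  # 448 336
```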

View File

@@ -0,0 +1,4 @@
user: "USER:"
bot: "ASSISTANT:"
turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|></s>\n"
context: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
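For reference, a rough sketch of how a single turn renders under this template (simplified by hand; the example user/assistant messages are made up, and the webui's own template handling is authoritative):

```python
context = ("A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n")
turn_template = "<|user|> <|user-message|>\n<|bot|> <|bot-message|></s>\n"

# Fill in the placeholders for one exchange.
turn = (turn_template
        .replace('<|user|>', 'USER:')
        .replace('<|user-message|>', 'What is shown in this image?')
        .replace('<|bot|>', 'ASSISTANT:')
        .replace('<|bot-message|>', 'A cat sitting on a windowsill.'))
print(context + turn, end='')
# ...questions.
#
# USER: What is shown in this image?
# ASSISTANT: A cat sitting on a windowsill.</s>
```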

View File

@@ -39,6 +39,8 @@
.*llava:
  instruction_template: 'LLaVA'
  custom_stopping_strings: '"\n###"'
.*llava.*1.5:
  instruction_template: 'LLaVA-v1'
.*wizard.*mega:
  instruction_template: 'Wizard-Mega'
  custom_stopping_strings: '"</s>"'
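Note that a v1.5 model folder name matches both the generic `.*llava` entry and the new `.*llava.*1.5` entry; my reading (an assumption about how the webui merges these config entries) is that the later, more specific pattern is what supplies the `LLaVA-v1` template:

```python
import re

name = 'liuhaotian_llava-v1.5-13b'.lower()
patterns = ['.*llava', '.*llava.*1.5']
print([p for p in patterns if re.match(p, name)])
# ['.*llava', '.*llava.*1.5'] -- both entries apply to this model name
```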