mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-23 01:59:21 +01:00
Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)
This commit is contained in:
parent
a2b25322f0
commit
e9e75a9ec7
1
.gitignore
vendored
1
.gitignore
vendored
@ -4,6 +4,7 @@ training/datasets
|
||||
extensions/silero_tts/outputs
|
||||
extensions/elevenlabs_tts/outputs
|
||||
extensions/sd_api_pictures/outputs
|
||||
extensions/multimodal/pipelines
|
||||
logs
|
||||
loras
|
||||
models
|
||||
|
@ -31,6 +31,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
||||
* [llama.cpp](docs/llama.cpp-models.md)
|
||||
* [RWKV model](docs/RWKV-model.md)
|
||||
* [LoRA (loading and training)](docs/Using-LoRAs.md)
|
||||
* [Multimodal pipelines, including LLaVA and MiniGPT-4](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal)
|
||||
* Softprompts
|
||||
* [Extensions](docs/Extensions.md) - see the [user extensions list](https://github.com/oobabooga/text-generation-webui-extensions)
|
||||
|
||||
@ -281,6 +282,12 @@ Optionally, you can use the following command-line flags:
|
||||
| `--api` | Enable the API extension. |
|
||||
| `--public-api` | Create a public URL for the API using Cloudfare. |
|
||||
|
||||
#### Multimodal
|
||||
|
||||
| Flag | Description |
|
||||
|---------------------------------------|-------------|
|
||||
| `--multimodal-pipeline PIPELINE` | The multimodal pipeline to use. Examples: `llava-7b`, `llava-13b`. |
|
||||
|
||||
Out of memory errors? [Check the low VRAM guide](docs/Low-VRAM-guide.md).
|
||||
|
||||
## Presets
|
||||
|
@ -1,4 +1,4 @@
|
||||
user: "### Human"
|
||||
bot: "### Assistant"
|
||||
turn_template: "<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n"
|
||||
context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"
|
||||
user: "### Human:"
|
||||
bot: "### Assistant:"
|
||||
turn_template: "<|user|> <|user-message|><|bot|> <|bot-message|>\n"
|
||||
context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.### Human: Hi!### Assistant: Hi there! How can I help you today?\n"
|
||||
|
@ -33,7 +33,7 @@ Most of these have been created by the extremely talented contributors that you
|
||||
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|
||||
|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. |
|
||||
|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). |
|
||||
|[llava](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava) | Adds LLaVA multimodal model support. For detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava/README.md) in the extension directory. |
|
||||
|[multimodal](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) | Adds multimodality support (text+images). For detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal/README.md) in the extension directory. |
|
||||
|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. |
|
||||
|[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. |
|
||||
|
||||
@ -50,7 +50,8 @@ Most of these have been created by the extremely talented contributors that you
|
||||
| `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply (more on that below). |
|
||||
| `def custom_generate_reply(...)` | Overrides the main text generation function. |
|
||||
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
|
||||
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `llava` extension for an example |
|
||||
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `multimodal` extension for an example |
|
||||
| `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See `multimodal` extension for an example |
|
||||
|
||||
Additionally, the script may define two special global variables:
|
||||
|
||||
@ -78,7 +79,7 @@ input_hijack = {
|
||||
```
|
||||
This is only relevant in chat mode. If your extension sets `input_hijack['state']` to `True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the values inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example.
|
||||
|
||||
Additionally, your extension can set the value to be a callback, in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `llava` extension above for an example.
|
||||
Additionally, your extension can set the value to be a callback, in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `multimodal` extension above for an example.
|
||||
|
||||
## The `bot_prefix_modifier`
|
||||
|
||||
@ -100,13 +101,22 @@ Marie Antoinette will become very enthusiastic in all her messages.
|
||||
|
||||
In order to use your extension, you must start the web UI with the `--extensions` flag followed by the name of your extension (the folder under `text-generation-webui/extension` where `script.py` resides).
|
||||
|
||||
You can activate more than one extension at a time by providing their names separated by spaces. The input, output and bot prefix modifiers will be applied in the specified order. For `custom_generate_chat_prompt`, only the first declaration encountered will be used and the rest will be ignored.
|
||||
You can activate more than one extension at a time by providing their names separated by spaces. The input, output and bot prefix modifiers will be applied in the specified order.
|
||||
|
||||
|
||||
```
|
||||
python server.py --extensions enthusiasm translate # First apply enthusiasm, then translate
|
||||
python server.py --extensions translate enthusiasm # First apply translate, then enthusiasm
|
||||
```
|
||||
|
||||
Do note, that for:
|
||||
- `custom_generate_chat_prompt`
|
||||
- `custom_generate_reply`
|
||||
- `tokenizer_modifier`
|
||||
- `custom_tokenized_length`
|
||||
|
||||
only the first declaration encountered will be used and the rest will be ignored.
|
||||
|
||||
## `custom_generate_reply` example
|
||||
|
||||
Once defined in a `script.py`, this function is executed in place of the main generation functions. You can use it to connect the web UI to an external API, or to load a custom model that is not supported yet.
|
||||
@ -167,7 +177,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
|
||||
# Building the prompt
|
||||
i = len(shared.history['internal']) - 1
|
||||
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||
while i >= 0 and get_encoded_length(''.join(rows)) < max_length:
|
||||
if _continue and i == len(shared.history['internal']) - 1:
|
||||
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
|
||||
else:
|
||||
@ -190,7 +200,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
|
||||
|
||||
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
|
||||
while len(rows) > min_rows and get_encoded_length(''.join(rows)) >= max_length:
|
||||
rows.pop(1)
|
||||
|
||||
prompt = ''.join(rows)
|
||||
|
@ -3,7 +3,7 @@ import traceback
|
||||
from threading import Thread
|
||||
from typing import Callable, Optional
|
||||
|
||||
from modules.text_generation import encode
|
||||
from modules.text_generation import get_encoded_length
|
||||
|
||||
|
||||
def build_parameters(body):
|
||||
@ -11,7 +11,7 @@ def build_parameters(body):
|
||||
|
||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
|
||||
while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context:
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
|
@ -1,71 +0,0 @@
|
||||
# LLaVA
|
||||
|
||||
## Description
|
||||
Adds [LLaVA 13B](https://github.com/haotian-liu/LLaVA) multimodality support to text-generation-webui.
|
||||
|
||||
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
|
||||
|
||||
## LLaVA-7B
|
||||
7B version currently isn't supported. It will be supported if/when [more generic multimodality support](https://github.com/oobabooga/text-generation-webui/discussions/1687) gets implemented.
|
||||
|
||||
## Usage
|
||||
To run this extension, download LLaVA weights, for example from [here](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g) (note: it's a 4-bit [GPTQ quantization](https://github.com/oobabooga/text-generation-webui/tree/main/docs/GPTQ-models-(4-bit-mode).md), done on "old CUDA" branch), and then start server.py with `--extensions llava` argument.
|
||||
|
||||
Do note, that each image takes up 258 tokens, so adjust max_new_tokens to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated.
|
||||
|
||||
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
|
||||
|
||||
Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. From initial testing, it seems as LLaVA considers the features in all images at the same time, so by default the extension skips previous images. If you want to include them anyway, just tick this checkbox.
|
||||
|
||||
## Extension config
|
||||
This extension uses following parameters (from settings.json):
|
||||
|Parameter|Description|
|
||||
|---------|-----------|
|
||||
|`llava-clip_bits`|Number of bits to load CLIP feature extractor in (either 32 or 16, default=32)|
|
||||
|`llava-clip_device`|Torch device to run the extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
||||
|`llava-clip_repo`|Huggingface repository of CLIP model, `openai/clip-vit-large-patch14` by default. There should be no need to change it|
|
||||
|`llava-projector_bits`|Number of bits to load CLIP->LLaMA feature projector in (either 32 or 16, default=32)|
|
||||
|`llava-projector_device`|Torch device to run the CLIP->LLaMA feature projector on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
||||
|`llava-projector_repo`|Huggingface repository of multimodal projector, `liuhaotian/LLaVA-13b-delta-v0` by default. There should be no need to change it|
|
||||
|`llava-projector_filename`|The filename of multimodal projector weights, `mm_projector.bin` by default. There should be no need to change it|
|
||||
|`llava-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
|
||||
## Technical description
|
||||
|
||||
### Original LLaVA
|
||||
The default LLaVA implementation uses modified `transformers` library, however this extension forgoes this requirement. The transformers are modified in LLaVA in such a way, that the entire LLaVA model gets loaded, and the inference now looks as follows:
|
||||
```
|
||||
images --> CLIP --> projector --> input embeddings for images --> |
|
||||
| --> LLaMA
|
||||
prompt -------------------------> input embeddings for text ----> |
|
||||
```
|
||||
The images are represented in the prompt by the following token IDs:
|
||||
- 32000 - `<im_patch>` - placeholder token for embeddings from projector
|
||||
- 32001 - `<im_start>` - token marking start of an image
|
||||
- 32002 - `<im_end>` - token marking end of an image
|
||||
|
||||
By default, image will be represented as `<im_start><im_patch>*256<im_end>`. The input embeddings for an image are converted with a single linear layer of the projector, then they are placed instead of `<im_patch>` tokens.
|
||||
The concatenated prompt then gets fed to fine-tuned LLaMA.
|
||||
|
||||
### In this extension
|
||||
|
||||
Using default transformers, they only load the LLaMA part of LLaVA, ignoring the added projector weights, and not loading CLIP. We then reconstruct the `images -> CLIP -> projector` pipeline ourselves, then concatenate the input embeddings, and feed it to LLaMA loaded by transformers. This allows us to use normal flow from webui to load this model, and just hijack the model input with additional features.
|
||||
Splitting it to 3 separate models, allows us to configure each of them, and to move them to different devices(for example we can run CLIP+projector on CPU and LLaMA on GPU). Also, it enables us to use 4-bit GPTQ quantization for LLaVA, massively cutting down the VRAM requirement (it should be possible to fit on 12GB of VRAM with full context size by moving CLIP and projector to CPU).
|
||||
|
||||
### Usage through API
|
||||
|
||||
You can run the multimodal inference through API, by inputting the images to prompt. Images are embedded like so: `f'<img src="data:image/jpeg;base64,{img_str}">'`, where `img_str` is base-64 jpeg data. Python example:
|
||||
```Python
|
||||
import base64
|
||||
import requests
|
||||
|
||||
CONTEXT = "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"
|
||||
|
||||
with open('extreme_ironing.jpg', 'rb') as f:
|
||||
img_str = base64.b64encode(f.read()).decode('utf-8')
|
||||
prompt = CONTEXT + f'### Human: \nWhat is unusual about this image: \n<img src="data:image/jpeg;base64,{img_str}">\n### Assistant: \n'
|
||||
print(requests.post('http://127.0.0.1:5000/api/v1/generate', json={'prompt': prompt, 'stopping_strings': ['\n###']}).json())
|
||||
```
|
||||
script output:
|
||||
```Python
|
||||
{'results': [{'text': "The unusual aspect of this image is that a man is standing on top of a yellow minivan while doing his laundry. He has set up a makeshift clothes line using the car's rooftop as an outdoor drying area. This scene is uncommon because people typically do their laundry indoors, in a dedicated space like a laundromat or a room in their home, rather than on top of a moving vehicle. Additionally, hanging clothes on the car could be potentially hazardous or illegal in some jurisdictions due to the risk of damaging the vehicle or causing accidents on the road.\n##"}]}
|
||||
```
|
@ -1,272 +1,6 @@
|
||||
import base64
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||
|
||||
from modules import shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.text_generation import encode, get_max_prompt_length
|
||||
|
||||
params = {
|
||||
"add_all_images_to_prompt": False,
|
||||
# device to run CLIP on
|
||||
"clip_device": None,
|
||||
# bits to load clip in either 32 or 16 (it doesn't support 8-bit)
|
||||
"clip_bits": 32,
|
||||
# clip repository
|
||||
"clip_repo": "openai/clip-vit-large-patch14",
|
||||
# device to run projector on
|
||||
"projector_device": None,
|
||||
# projector bits, either 32 or 16
|
||||
"projector_bits": 32,
|
||||
# projector repository
|
||||
"projector_repo": "liuhaotian/LLaVA-13b-delta-v0",
|
||||
# file with the projector weights
|
||||
"projector_file": "mm_projector.bin"
|
||||
}
|
||||
|
||||
|
||||
# If 'state' is True, will hijack the next chat generation
|
||||
input_hijack = {
|
||||
'state': False,
|
||||
'value': ["", ""]
|
||||
}
|
||||
|
||||
|
||||
# initialized in ui, so that params are loaded from settings
|
||||
llava_embedder = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Token:
|
||||
token: str
|
||||
id: int
|
||||
|
||||
|
||||
class LLaVAEmbedder:
|
||||
IM_PATCH = Token("<im_patch>", 32000)
|
||||
IM_START = Token("<im_start>", 32001)
|
||||
IM_END = Token("<im_end>", 32002)
|
||||
|
||||
def __init__(self):
|
||||
self.clip_device = self._get_device("clip_device")
|
||||
self.clip_dtype = self._get_dtype("clip_bits")
|
||||
self.projector_device = self._get_device("projector_device")
|
||||
self.projector_dtype = self._get_dtype("projector_bits")
|
||||
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
|
||||
|
||||
def _get_device(self, setting_name):
|
||||
if params[setting_name] is None:
|
||||
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
return torch.device(params[setting_name])
|
||||
|
||||
def _get_dtype(self, setting_name):
|
||||
return torch.float32 if int(params[setting_name]) == 32 else torch.float16
|
||||
|
||||
def _load_models(self):
|
||||
start_ts = time.time()
|
||||
|
||||
print(f"LLaVA - Loading CLIP from {params['clip_repo']} as {self.clip_dtype} on {self.clip_device}...")
|
||||
image_processor = CLIPImageProcessor.from_pretrained(params["clip_repo"], torch_dtype=self.clip_dtype)
|
||||
vision_tower = CLIPVisionModel.from_pretrained(params["clip_repo"], torch_dtype=self.clip_dtype).to(self.clip_device)
|
||||
|
||||
print(f"LLaVA - Loading projector from {params['projector_repo']} as {self.projector_dtype} on {self.projector_device}...")
|
||||
projector_path = hf_hub_download(params["projector_repo"], params["projector_file"])
|
||||
mm_projector = torch.nn.Linear(1024, 5120)
|
||||
projector_data = torch.load(projector_path)
|
||||
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector = mm_projector.to(self.projector_device)
|
||||
|
||||
print(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
||||
return image_processor, vision_tower, mm_projector
|
||||
|
||||
def _update_prompt(self, prompt, images):
|
||||
for _ in images:
|
||||
# replace the image token with the image patch token in the prompt (each occurrence)
|
||||
replace_token = LLaVAEmbedder.IM_PATCH.token * 256
|
||||
replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token
|
||||
prompt = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', replace_token, prompt, 1)
|
||||
return prompt
|
||||
|
||||
def _extract_image_features(self, images):
|
||||
images = self.image_processor(images, return_tensors='pt')['pixel_values']
|
||||
images = images.to(self.clip_device, dtype=self.clip_dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
|
||||
select_hidden_state_layer = -2
|
||||
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
|
||||
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
|
||||
image_features = self.mm_projector(image_features)
|
||||
return image_features
|
||||
|
||||
def forward(self, prompt, images, state):
|
||||
prompt = self._update_prompt(prompt, images)
|
||||
input_ids = encode(prompt, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))[0]
|
||||
input_embeds = shared.model.model.embed_tokens(input_ids).to(self.projector_device)
|
||||
|
||||
if input_ids[0] == LLaVAEmbedder.IM_PATCH.id:
|
||||
# prompt got truncated in the middle of an image, remove the image data
|
||||
im_end = torch.where(input_ids == LLaVAEmbedder.IM_END.id)[0][0]
|
||||
input_ids = input_ids[im_end+1:]
|
||||
input_embeds = input_embeds[im_end+1:]
|
||||
leftover_images = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0].shape[0]
|
||||
print(f"LLaVA - WARNING: removed {len(images) - leftover_images} image(s) from prompt. The generation might be broken, try decreasing max_new_tokens")
|
||||
images = images[-leftover_images:]
|
||||
if len(images) == 0:
|
||||
return prompt, input_ids, input_embeds, 0
|
||||
|
||||
total_embedded = 0
|
||||
image_features = self._extract_image_features(images).to(self.projector_device)
|
||||
image_start_tokens = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0]
|
||||
|
||||
if not torch.any(input_ids == LLaVAEmbedder.IM_PATCH.id) or len(image_start_tokens) == 0:
|
||||
# multimodal LLM, but the current prompt is not multimodal/truncated
|
||||
return prompt, input_ids, input_embeds, total_embedded
|
||||
|
||||
cur_image_idx = 0
|
||||
if not params['add_all_images_to_prompt']:
|
||||
image_start_tokens = [image_start_tokens[-1]]
|
||||
cur_image_idx = -1
|
||||
|
||||
for image_start_token_pos in image_start_tokens:
|
||||
cur_image_features = image_features[cur_image_idx]
|
||||
num_patches = cur_image_features.shape[0]
|
||||
input_embeds = torch.cat((input_embeds[:image_start_token_pos+1], cur_image_features, input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
|
||||
cur_image_idx += 1
|
||||
total_embedded += 1
|
||||
|
||||
return prompt, input_ids, input_embeds, total_embedded
|
||||
|
||||
@staticmethod
|
||||
def len_in_tokens(text):
|
||||
images = re.findall(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', text)
|
||||
image_tokens = 0
|
||||
for _ in images:
|
||||
image_tokens += 258
|
||||
return len(encode(re.sub(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', '', text))[0]) + image_tokens
|
||||
|
||||
|
||||
def add_chat_picture(picture, text, visible_text):
|
||||
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
|
||||
max_hw, min_hw = max(picture.size), min(picture.size)
|
||||
aspect_ratio = max_hw / min_hw
|
||||
shortest_edge = int(max(300 / aspect_ratio, 224))
|
||||
longest_edge = int(shortest_edge * aspect_ratio)
|
||||
w = shortest_edge if picture.width < picture.height else longest_edge
|
||||
h = shortest_edge if picture.width >= picture.height else longest_edge
|
||||
picture = picture.resize((w,h))
|
||||
|
||||
buffer = BytesIO()
|
||||
picture.save(buffer, format="JPEG")
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
image = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||
|
||||
|
||||
if '<image>' in text:
|
||||
text = text.replace('<image>', image)
|
||||
else:
|
||||
text = text + '\n' + image
|
||||
|
||||
if visible_text == '' or visible_text is None:
|
||||
visible_text = text
|
||||
elif '<image>' in visible_text:
|
||||
visible_text = visible_text.replace('<image>', image)
|
||||
else:
|
||||
visible_text = visible_text + '\n' + image
|
||||
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
rows = [f"{state['context'].strip()}\n"]
|
||||
min_rows = 3
|
||||
|
||||
# Finding the maximum prompt size
|
||||
chat_prompt_size = state['chat_prompt_size']
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
|
||||
prefix1 = f"{state['name1']}: "
|
||||
prefix2 = f"{state['name2']}: "
|
||||
|
||||
i = len(shared.history['internal']) - 1
|
||||
while i >= 0 and LLaVAEmbedder.len_in_tokens(''.join(rows)) < max_length:
|
||||
if _continue and i == len(shared.history['internal']) - 1:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||
else:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}\n")
|
||||
|
||||
string = shared.history['internal'][i][0]
|
||||
if string != '':
|
||||
rows.insert(1, f"{prefix1}{string.strip()}\n")
|
||||
|
||||
i -= 1
|
||||
|
||||
if impersonate:
|
||||
min_rows = 2
|
||||
rows.append(f"{prefix1}")
|
||||
elif not _continue:
|
||||
# Adding the user message
|
||||
if len(user_input) > 0:
|
||||
rows.append(f"{prefix1}{user_input}\n")
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions("bot_prefix", f"{prefix2}"))
|
||||
|
||||
while len(rows) > min_rows and LLaVAEmbedder.len_in_tokens(''.join(rows)) >= max_length:
|
||||
rows.pop(1)
|
||||
prompt = ''.join(rows)
|
||||
|
||||
if also_return_rows:
|
||||
return prompt, rows
|
||||
else:
|
||||
return prompt
|
||||
|
||||
|
||||
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
||||
global params
|
||||
start_ts = time.time()
|
||||
image_matches = re.finditer(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt)
|
||||
images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches]
|
||||
|
||||
if len(images) == 0:
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state)
|
||||
print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
|
||||
return (prompt,
|
||||
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
|
||||
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
|
||||
|
||||
import logging
|
||||
|
||||
def ui():
|
||||
global llava_embedder
|
||||
llava_embedder = LLaVAEmbedder()
|
||||
with gr.Column():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
# I found that it doesn't deal super well with multiple images, and demo ui had a bug where it included only the last image anyway
|
||||
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
|
||||
# Prepare the input hijack
|
||||
picture_select.upload(
|
||||
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
|
||||
[picture_select],
|
||||
None
|
||||
)
|
||||
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
|
||||
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
|
||||
shared.gradio['Generate'].click(lambda: None, None, picture_select)
|
||||
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
|
||||
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
|
||||
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")
|
||||
|
85
extensions/multimodal/DOCS.md
Normal file
85
extensions/multimodal/DOCS.md
Normal file
@ -0,0 +1,85 @@
|
||||
# Technical description of multimodal extension
|
||||
|
||||
## Working principle
|
||||
Multimodality extension does most of the stuff which is required for any image input:
|
||||
|
||||
- adds the UI
|
||||
- saves the images as base64 JPEGs to history
|
||||
- provides the hooks to the UI
|
||||
- if there are images in the prompt, it:
|
||||
- splits the prompt to text and image parts
|
||||
- adds image start/end markers to text parts, then encodes and embeds the text parts
|
||||
- calls the vision pipeline to embed the images
|
||||
- stitches the embeddings together, and returns them to text generation
|
||||
- loads the appropriate vision pipeline, selected either from model name, or by specifying --multimodal-pipeline parameter
|
||||
|
||||
Now, for the pipelines, they:
|
||||
|
||||
- load the required vision models
|
||||
- return some consts, for example the number of tokens taken up by image
|
||||
- and most importantly: return the embeddings for LLM, given a list of images
|
||||
|
||||
## Prompts/history
|
||||
|
||||
To save images in prompt/history, this extension is using a base64 JPEG, wrapped in a HTML tag, like so:
|
||||
```
|
||||
<img src="data:image/jpeg;base64,{img_str}">
|
||||
```
|
||||
where `{img_str}` is the actual image data. This format makes displaying them in the UI for free. Do note, that this format is required to be exactly the same, the regex used to find the images is: `<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">`.
|
||||
|
||||
## LLM input
|
||||
To describe the input, let's see it on an example prompt:
|
||||
```
|
||||
text1<image1>text2<image2>text3
|
||||
```
|
||||
where `textN` is N-th text, `<imageN>` is N-th image, in HTML format specified above.
|
||||
|
||||
**The first step is to split the prompt into image/text parts**, so we get:
|
||||
```
|
||||
['text1', '<image1>', 'text2', '<image2>', 'text3']
|
||||
```
|
||||
this is done in `MultimodalEmbedder._split_prompt(...)` function, which returns a list of `PromptPart`s - dataclasses wrapping the separate parts.
|
||||
|
||||
This function also appends the image start/end markers to text, which are provided by `AbstractMultimodalPipeline.image_start()` / `AbstractMultimodalPipeline.image_end()` functions. If image start is `<Img>`, and end is `</Img>`, this function will return:
|
||||
```
|
||||
['text1<Img>', '<image1>', '</Img>text2<Img>', '<image2>', '</Img>text3']
|
||||
```
|
||||
|
||||
**The returned prompt parts are then turned into token embeddings.**
|
||||
|
||||
First, they are modified to token IDs, for the text it is done using standard `modules.text_generation.encode()` function, and for the images the returned token IDs are changed to placeholders. The placeholder is a list of `N` times `placeholder token id`, where `N` is specified using `AbstractMultimodalPipeline.num_image_embeds()`, and placeholder token IDs using `AbstractMultimodalPipeline.placeholder_token_id()`.
|
||||
|
||||
Now, based on the token IDs, the prompt might get truncated, especially if `max_new_tokens` are unreasonably high. Unfortunately, it can't be done simply, just by trimming the prompt to be short enough. This way will lead to sometimes splitting the prompt in the middle of an image embedding, which usually breaks the generation. Therefore, in this case, the entire image needs to be removed from input. This is done inside `MultimodalEmbedder._encode_text(...)` function.
|
||||
|
||||
**After the tokenization, the tokens need to get embedded**, the text and images are once again treated separately.
|
||||
|
||||
The text parts are turned to embeddings, using `AbstractMultimodalPipeline.embed_tokens(...)` function. It uses standard embedding function from the model, but to support many LLMs, the actual function is returned by the pipeline (as it might be different for different LLMs), for LLaMA it is `shared.model.model.embed_tokens(...)`.
|
||||
|
||||
The image parts are turned to embeddings, using `AbstractMultimodalPipeline.embed_images(...)` function. This function is specific for a given pipeline, it takes the images as input, forwards them through vision model/projector, and returns the embeddings.
|
||||
|
||||
**Now, the returned embeddings are stitched together**, using `torch.cat()`, this is creating the final input to the LLM.
|
||||
|
||||
## Pipelines
|
||||
|
||||
All of the pipelines should subclass `AbstractMultimodalPipeline` class. The idea is to allow for new pipelines to be added in the same way as user extensions - git clone into `extensions/multimodal/pipelines`.
|
||||
|
||||
The pipelines are the description of the vision part, containing vision model/multimodal projector. All of the pipelines should have an unique `name()`, which is then selected by user, in `--multimodal-pipeline` CLI argument. For an example, see `pipelines/llava/llava.py`.
|
||||
|
||||
## Pipeline modules
|
||||
|
||||
Pipelines are organized into "pipeline modules" - subdirectories in `pipelines` directory. The pipeline modules should contain a file called `pipelines.py`, that should contain the following fields:
|
||||
- `available_pipelines: List[str]` - list of pipelines provided by this module, shown as the list of available pipelines to the user
|
||||
- `def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]`: - a function to get a concrete pipeline by `name`, if `name` doesn't match any, should return `None`. `params` is the user settings for multimodal extension
|
||||
- `def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]`: - a function to get a pipeline from `model_name`, should be eager to return `None`, unless the determination can be done clearly (for example: minigpt-4 bases on vicuna - it should never return the pipeline, but llava can, as it has its own specific LLM finetune)
|
||||
|
||||
**NOTE**: A pipeline module should lazy-import the pipelines only when necessary, and it should keep its imports to minimum
|
||||
|
||||
## Pipeline params
|
||||
|
||||
The pipelines will get the extension `params` in the constructor. They should honor the following fields:
|
||||
- `vision_device` - string, specifying `torch.device` to run the vision model (CLIP/ViT) on
|
||||
- `vision_bits` - int, number of fp bits to load the vision model(s) in
|
||||
- `projector_device` - string, specifying `torch.device` to run the projector models (Linear layers, QFormer, etc.) on
|
||||
- `projector_bits` - int, number of fp bits to load the projector models in
|
||||
|
||||
As a helper, `AbstractMultimodalPipeline` has `_get_device(self, setting_name: str, params: dict)` and `_get_dtype(self, setting_name: str, params: dict)` helper functions, which parse string/int and return `torch.device` / `torch.dtype`.
|
78
extensions/multimodal/README.md
Normal file
78
extensions/multimodal/README.md
Normal file
@ -0,0 +1,78 @@
|
||||
# Multimodal
|
||||
|
||||
## Description
|
||||
|
||||
Adds support for multimodality (text+images) to text-generation-webui.
|
||||
|
||||
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
|
||||
|
||||
## Usage
|
||||
|
||||
To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples:
|
||||
|
||||
```
|
||||
python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b --chat
|
||||
python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b --chat
|
||||
python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b --chat
|
||||
python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b --chat
|
||||
```
|
||||
|
||||
There is built-in support for LLaVA-v0-13B and LLaVA-v0-7b. To install `minigpt4`:
|
||||
|
||||
- clone https://github.com/Wojtab/minigpt-4-pipeline into `extensions/multimodal/pipelines`
|
||||
- install the requirements.txt
|
||||
|
||||
The same procedure should be used to install other pipelines, which can then me used with `--multimodal-pipeline [pipeline name]`. For additional multimodal pipelines refer to compatibility section below.
|
||||
|
||||
Do note, that each image takes up a considerable amount of tokens, so adjust `max_new_tokens` to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated.
|
||||
|
||||
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
|
||||
|
||||
Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. It seems as some multimodal networks consider the features in all images at the same time as if they were a single image. Due to this behavior, by default the extension skips previous images. However, it can lead to sub-par generation on other pipelines. If you want to include all images, just tick this checkbox.
|
||||
|
||||
## Compatibility
|
||||
As of now, the following multimodal pipelines are supported:
|
||||
|Pipeline|`--multimodal-pipeline`|Default LLM|LLM info(for the linked model)|Pipeline repository|
|
||||
|-|-|-|-|-|
|
||||
|[LLaVA 13B](https://github.com/haotian-liu/LLaVA)|`llava-13b`|[LLaVA 13B](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|
||||
|[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|
||||
|[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|
||||
|[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|
||||
|
||||
Some pipelines could support different LLMs, but do note that while it might work, it isn't a supported configuration.
|
||||
|
||||
DO NOT report bugs if you are using a different LLM.
|
||||
|
||||
DO NOT report bugs with pipelines in this repository (unless they are built-in)
|
||||
|
||||
## Extension config
|
||||
This extension uses following parameters (from settings.json):
|
||||
|Parameter|Description|
|
||||
|---------|-----------|
|
||||
|`multimodal-vision_bits`|Number of bits to load vision models (CLIP/ViT) feature extractor in (most pipelines should support either 32 or 16, default=32)|
|
||||
|`multimodal-vision_device`|Torch device to run the feature extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
||||
|`multimodal-projector_bits`|Number of bits to load feature projector model(s) in (most pipelines should support either 32 or 16, default=32)|
|
||||
|`multimodal-projector_device`|Torch device to run the feature projector model(s) on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
||||
|`multimodal-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
|
||||
|
||||
## Usage through API
|
||||
|
||||
You can run the multimodal inference through API, by inputting the images to prompt. Images are embedded like so: `f'<img src="data:image/jpeg;base64,{img_str}">'`, where `img_str` is base-64 jpeg data. Python example:
|
||||
```Python
|
||||
import base64
|
||||
import requests
|
||||
|
||||
CONTEXT = "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.### Human: Hi!### Assistant: Hi there! How can I help you today?\n"
|
||||
|
||||
with open('extreme_ironing.jpg', 'rb') as f:
|
||||
img_str = base64.b64encode(f.read()).decode('utf-8')
|
||||
prompt = CONTEXT + f'### Human: What is unusual about this image: \n<img src="data:image/jpeg;base64,{img_str}">### Assistant: '
|
||||
print(requests.post('http://127.0.0.1:5000/api/v1/generate', json={'prompt': prompt, 'stopping_strings': ['\n###']}).json())
|
||||
```
|
||||
script output:
|
||||
```Python
|
||||
{'results': [{'text': "The unusual aspect of this image is that a man is standing on top of a yellow minivan while doing his laundry. He has set up a makeshift clothes line using the car's rooftop as an outdoor drying area. This scene is uncommon because people typically do their laundry indoors, in a dedicated space like a laundromat or a room in their home, rather than on top of a moving vehicle. Additionally, hanging clothes on the car could be potentially hazardous or illegal in some jurisdictions due to the risk of damaging the vehicle or causing accidents on the road.\n##"}]}
|
||||
```
|
||||
|
||||
## For pipeline developers/technical description
|
||||
see [DOCS.md](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/multimodal/DOCS.md)
|
62
extensions/multimodal/abstract_pipeline.py
Normal file
62
extensions/multimodal/abstract_pipeline.py
Normal file
@ -0,0 +1,62 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class AbstractMultimodalPipeline(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def name() -> str:
|
||||
'name of the pipeline, should be same as in --multimodal-pipeline'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def image_start() -> Optional[str]:
|
||||
'return image start string, string representation of image start token, or None if not applicable'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def image_end() -> Optional[str]:
|
||||
'return image end string, string representation of image end token, or None if not applicable'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def placeholder_token_id() -> int:
|
||||
'return placeholder token id'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def num_image_embeds() -> int:
|
||||
'return the number of embeds used by a single image (for example: 256 for LLaVA)'
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
|
||||
'forward the images through vision pipeline, and return their embeddings'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
'embed tokens, the exact function varies by LLM, for LLaMA it is `shared.model.model.embed_tokens`'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def placeholder_embeddings() -> torch.Tensor:
|
||||
'get placeholder embeddings if there are multiple images, and `add_all_images_to_prompt` is False'
|
||||
pass
|
||||
|
||||
def _get_device(self, setting_name: str, params: dict):
|
||||
if params[setting_name] is None:
|
||||
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
return torch.device(params[setting_name])
|
||||
|
||||
def _get_dtype(self, setting_name: str, params: dict):
|
||||
return torch.float32 if int(params[setting_name]) == 32 else torch.float16
|
177
extensions/multimodal/multimodal_embedder.py
Normal file
177
extensions/multimodal/multimodal_embedder.py
Normal file
@ -0,0 +1,177 @@
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
from extensions.multimodal.pipeline_loader import load_pipeline
|
||||
from modules import shared
|
||||
from modules.text_generation import encode, get_max_prompt_length
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptPart:
|
||||
text: str
|
||||
image: Optional[Image.Image] = None
|
||||
is_image: bool = False
|
||||
input_ids: Optional[torch.Tensor] = None
|
||||
embedding: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class MultimodalEmbedder:
|
||||
def __init__(self, params: dict):
|
||||
pipeline, source = load_pipeline(params)
|
||||
self.pipeline = pipeline
|
||||
logging.info(f'Multimodal: loaded pipeline {self.pipeline.name()} from pipelines/{source} ({self.pipeline.__class__.__name__})')
|
||||
|
||||
def _split_prompt(self, prompt: str, load_images: bool = False) -> List[PromptPart]:
|
||||
"""Splits a prompt into a list of `PromptParts` to separate image data from text.
|
||||
It will also append `image_start` and `image_end` before and after the image, and optionally parse and load the images,
|
||||
if `load_images` is `True`.
|
||||
"""
|
||||
parts: List[PromptPart] = []
|
||||
curr = 0
|
||||
while True:
|
||||
match = re.search(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt[curr:])
|
||||
if match is None:
|
||||
# no more image tokens, append the rest of the prompt
|
||||
if curr > 0:
|
||||
# add image end token after last image
|
||||
parts.append(PromptPart(text=self.pipeline.image_end() + prompt[curr:]))
|
||||
else:
|
||||
parts.append(PromptPart(text=prompt))
|
||||
break
|
||||
# found an image, append image start token to the text
|
||||
if match.start() > 0:
|
||||
parts.append(PromptPart(text=prompt[curr:curr+match.start()]+self.pipeline.image_start()))
|
||||
else:
|
||||
parts.append(PromptPart(text=self.pipeline.image_start()))
|
||||
# append the image
|
||||
parts.append(PromptPart(
|
||||
text=match.group(0),
|
||||
image=Image.open(BytesIO(base64.b64decode(match.group(1)))) if load_images else None,
|
||||
is_image=True
|
||||
))
|
||||
curr += match.end()
|
||||
return parts
|
||||
|
||||
def _len_in_tokens_prompt_parts(self, parts: List[PromptPart]) -> int:
|
||||
"""Total length in tokens of all `parts`"""
|
||||
tokens = 0
|
||||
for part in parts:
|
||||
if part.is_image:
|
||||
tokens += self.pipeline.num_image_embeds()
|
||||
elif part.input_ids is not None:
|
||||
tokens += len(part.input_ids)
|
||||
else:
|
||||
tokens += len(encode(part.text)[0])
|
||||
return tokens
|
||||
|
||||
def len_in_tokens(self, prompt: str) -> int:
|
||||
"""Total length in tokens for a given text `prompt`"""
|
||||
parts = self._split_prompt(prompt, False)
|
||||
return self._len_in_tokens_prompt_parts(parts)
|
||||
|
||||
def _encode_single_text(self, part: PromptPart, add_bos_token: bool) -> PromptPart:
|
||||
"""Encode a single prompt `part` to `input_ids`. Returns a `PromptPart`"""
|
||||
if part.is_image:
|
||||
placeholders = torch.ones((self.pipeline.num_image_embeds())) * self.pipeline.placeholder_token_id()
|
||||
part.input_ids = placeholders.to(shared.model.device, dtype=torch.int64)
|
||||
else:
|
||||
part.input_ids = encode(part.text, add_bos_token=add_bos_token)[0].to(shared.model.device, dtype=torch.int64)
|
||||
return part
|
||||
|
||||
@staticmethod
|
||||
def _num_images(parts: List[PromptPart]) -> int:
|
||||
count = 0
|
||||
for part in parts:
|
||||
if part.is_image:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _encode_text(self, state, parts: List[PromptPart]) -> List[PromptPart]:
|
||||
"""Encode text to token_ids, also truncate the prompt, if necessary.
|
||||
|
||||
The chat/instruct mode should make prompts that fit in get_max_prompt_length, but if max_new_tokens are set
|
||||
such that the context + min_rows don't fit, we can get a prompt which is too long.
|
||||
We can't truncate image embeddings, as it leads to broken generation, so remove the images instead and warn the user
|
||||
"""
|
||||
encoded: List[PromptPart] = []
|
||||
for i, part in enumerate(parts):
|
||||
encoded.append(self._encode_single_text(part, i==0 and state['add_bos_token']))
|
||||
|
||||
# truncation:
|
||||
max_len = get_max_prompt_length(state)
|
||||
removed_images = 0
|
||||
|
||||
# 1. remove entire text/image blocks
|
||||
while self._len_in_tokens_prompt_parts(encoded[1:]) > max_len:
|
||||
if encoded[0].is_image:
|
||||
removed_images += 1
|
||||
encoded = encoded[1:]
|
||||
|
||||
# 2. check if the last prompt part doesn't need to get truncated
|
||||
if self._len_in_tokens_prompt_parts(encoded) > max_len:
|
||||
if encoded[0].is_image:
|
||||
# don't truncate image embeddings, just remove the image, otherwise generation will be broken
|
||||
removed_images += 1
|
||||
encoded = encoded[1:]
|
||||
elif len(encoded) > 1 and encoded[0].text.endswith(self.pipeline.image_start()):
|
||||
# see if we can keep image_start token
|
||||
len_image_start = len(encode(self.pipeline.image_start(), add_bos_token=state['add_bos_token'])[0])
|
||||
if self._len_in_tokens_prompt_parts(encoded[1:]) + len_image_start > max_len:
|
||||
# we can't -> remove this text, and the image
|
||||
encoded = encoded[2:]
|
||||
removed_images += 1
|
||||
else:
|
||||
# we can -> just truncate the text
|
||||
trunc_len = self._len_in_tokens_prompt_parts(encoded) - max_len
|
||||
encoded[0].input_ids = encoded[0].input_ids[trunc_len:]
|
||||
elif len(encoded) > 0:
|
||||
# only one text left, truncate it normally
|
||||
trunc_len = self._len_in_tokens_prompt_parts(encoded) - max_len
|
||||
encoded[0].input_ids = encoded[0].input_ids[trunc_len:]
|
||||
|
||||
# notify user if we truncated an image
|
||||
if removed_images > 0:
|
||||
logging.warning(f"Multimodal: removed {removed_images} image(s) from prompt. Try decreasing max_new_tokens if generation is broken")
|
||||
|
||||
return encoded
|
||||
|
||||
def _embed(self, parts: List[PromptPart]) -> List[PromptPart]:
|
||||
# batch images
|
||||
image_indicies = [i for i, part in enumerate(parts) if part.is_image]
|
||||
embedded = self.pipeline.embed_images([parts[i].image for i in image_indicies])
|
||||
for i, embeds in zip(image_indicies, embedded):
|
||||
parts[i].embedding = embeds
|
||||
# embed text
|
||||
for (i, part) in enumerate(parts):
|
||||
if not part.is_image:
|
||||
parts[i].embedding = self.pipeline.embed_tokens(part.input_ids)
|
||||
return parts
|
||||
|
||||
def _remove_old_images(self, parts: List[PromptPart], params: dict) -> List[PromptPart]:
|
||||
if params['add_all_images_to_prompt']:
|
||||
return parts
|
||||
already_added = False
|
||||
for i, part in reversed(list(enumerate(parts))):
|
||||
if part.is_image:
|
||||
if already_added:
|
||||
parts[i].embedding = self.pipeline.placeholder_embeddings()
|
||||
else:
|
||||
already_added = True
|
||||
return parts
|
||||
|
||||
def forward(self, prompt: str, state: Any, params: dict):
|
||||
prompt_parts = self._split_prompt(prompt, True)
|
||||
prompt_parts = self._encode_text(state, prompt_parts)
|
||||
prompt_parts = self._embed(prompt_parts)
|
||||
prompt_parts = self._remove_old_images(prompt_parts, params)
|
||||
embeds = tuple(part.embedding for part in prompt_parts)
|
||||
ids = tuple(part.input_ids for part in prompt_parts)
|
||||
input_embeds = torch.cat(embeds, dim=0)
|
||||
input_ids = torch.cat(ids, dim=0)
|
||||
return prompt, input_ids, input_embeds, self._num_images(prompt_parts)
|
52
extensions/multimodal/pipeline_loader.py
Normal file
52
extensions/multimodal/pipeline_loader.py
Normal file
@ -0,0 +1,52 @@
|
||||
import logging
|
||||
import traceback
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
from modules import shared
|
||||
|
||||
|
||||
def _get_available_pipeline_modules():
|
||||
pipeline_path = Path(__file__).parent / 'pipelines'
|
||||
modules = [p for p in pipeline_path.iterdir() if p.is_dir()]
|
||||
return [m.name for m in modules if (m / 'pipelines.py').exists()]
|
||||
|
||||
|
||||
def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
|
||||
pipeline_modules = {}
|
||||
available_pipeline_modules = _get_available_pipeline_modules()
|
||||
for name in available_pipeline_modules:
|
||||
try:
|
||||
pipeline_modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
|
||||
except:
|
||||
logging.warning(f'Failed to get multimodal pipelines from {name}')
|
||||
logging.warning(traceback.format_exc())
|
||||
|
||||
if shared.args.multimodal_pipeline is not None:
|
||||
for k in pipeline_modules:
|
||||
if hasattr(pipeline_modules[k], 'get_pipeline'):
|
||||
pipeline = getattr(pipeline_modules[k], 'get_pipeline')(shared.args.multimodal_pipeline, params)
|
||||
if pipeline is not None:
|
||||
return (pipeline, k)
|
||||
else:
|
||||
model_name = shared.args.model.lower()
|
||||
for k in pipeline_modules:
|
||||
if hasattr(pipeline_modules[k], 'get_pipeline_from_model_name'):
|
||||
pipeline = getattr(pipeline_modules[k], 'get_pipeline_from_model_name')(model_name, params)
|
||||
if pipeline is not None:
|
||||
return (pipeline, k)
|
||||
|
||||
available = []
|
||||
for k in pipeline_modules:
|
||||
if hasattr(pipeline_modules[k], 'available_pipelines'):
|
||||
pipelines = getattr(pipeline_modules[k], 'available_pipelines')
|
||||
available += pipelines
|
||||
|
||||
if shared.args.multimodal_pipeline is not None:
|
||||
log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{shared.args.multimodal_pipeline}", available pipelines are: {available}.'
|
||||
else:
|
||||
log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'
|
||||
logging.critical(f'{log} Please specify a correct pipeline, or disable the extension')
|
||||
raise RuntimeError(f'{log} Please specify a correct pipeline, or disable the extension')
|
9
extensions/multimodal/pipelines/llava/README.md
Normal file
9
extensions/multimodal/pipelines/llava/README.md
Normal file
@ -0,0 +1,9 @@
|
||||
## LLaVA pipeline
|
||||
|
||||
This module provides 2 pipelines:
|
||||
- `llava-7b` - for use with LLaVA v0 7B model (finetuned LLaMa 7B)
|
||||
- `llava-13b` - for use with LLaVA v0 13B model (finetuned LLaMa 13B)
|
||||
|
||||
[LLaVA](https://github.com/haotian-liu/LLaVA) uses CLIP `openai/clip-vit-large-patch14` as the vision model, and then a single linear layer. For 13B the projector weights are in `liuhaotian/LLaVA-13b-delta-v0`, and for 7B they are in `liuhaotian/LLaVA-7b-delta-v0`.
|
||||
|
||||
The supported parameter combinations for both the vision model, and the projector are: CUDA/32bit, CUDA/16bit, CPU/32bit
|
139
extensions/multimodal/pipelines/llava/llava.py
Normal file
139
extensions/multimodal/pipelines/llava/llava.py
Normal file
@ -0,0 +1,139 @@
|
||||
import logging
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modules import shared
|
||||
from modules.text_generation import encode
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||
|
||||
|
||||
class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
||||
CLIP_REPO = "openai/clip-vit-large-patch14"
|
||||
|
||||
def __init__(self, params: dict) -> None:
|
||||
super().__init__()
|
||||
self.clip_device = self._get_device("vision_device", params)
|
||||
self.clip_dtype = self._get_dtype("vision_bits", params)
|
||||
self.projector_device = self._get_device("projector_device", params)
|
||||
self.projector_dtype = self._get_dtype("projector_bits", params)
|
||||
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
|
||||
|
||||
def _load_models(self):
|
||||
start_ts = time.time()
|
||||
|
||||
logging.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
|
||||
image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
|
||||
vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
|
||||
|
||||
logging.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
|
||||
projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
|
||||
mm_projector = torch.nn.Linear(*self.llava_projector_shape())
|
||||
projector_data = torch.load(projector_path)
|
||||
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector = mm_projector.to(self.projector_device)
|
||||
|
||||
logging.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
||||
return image_processor, vision_tower, mm_projector
|
||||
|
||||
@staticmethod
|
||||
def image_start() -> str:
|
||||
return "<im_start>"
|
||||
|
||||
@staticmethod
|
||||
def image_end() -> str:
|
||||
return "<im_end>"
|
||||
|
||||
@staticmethod
|
||||
def num_image_embeds() -> int:
|
||||
return 256
|
||||
|
||||
@staticmethod
|
||||
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
|
||||
|
||||
@staticmethod
|
||||
def placeholder_embeddings() -> torch.Tensor:
|
||||
return LLaVA_v0_Pipeline.embed_tokens(encode("<im_patch>"*256, add_bos_token=False)[0])
|
||||
|
||||
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
|
||||
images = self.image_processor(images, return_tensors='pt')['pixel_values']
|
||||
images = images.to(self.clip_device, dtype=self.clip_dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
|
||||
select_hidden_state_layer = -2
|
||||
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
|
||||
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
|
||||
image_features = self.mm_projector(image_features)
|
||||
return image_features.to(shared.model.device, dtype=shared.model.dtype)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def llava_projector_repo() -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def llava_projector_filename() -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def llava_projector_shape() -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
|
||||
class LLaVA_v0_13B_Pipeline(LLaVA_v0_Pipeline):
|
||||
def __init__(self, params: dict) -> None:
|
||||
super().__init__(params)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "llava-13b"
|
||||
|
||||
@staticmethod
|
||||
def placeholder_token_id() -> int:
|
||||
return 32000
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_shape() -> Tuple[int, int]:
|
||||
return (1024, 5120)
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_filename() -> str:
|
||||
return "mm_projector.bin"
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_repo() -> str:
|
||||
return "liuhaotian/LLaVA-13b-delta-v0"
|
||||
|
||||
|
||||
class LLaVA_v0_7B_Pipeline(LLaVA_v0_Pipeline):
|
||||
def __init__(self, params: dict) -> None:
|
||||
super().__init__(params)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "llava-7b"
|
||||
|
||||
@staticmethod
|
||||
def placeholder_token_id() -> int:
|
||||
return 32001
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_shape() -> Tuple[int, int]:
|
||||
return (1024, 4096)
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_filename() -> str:
|
||||
return "mm_projector.bin"
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_repo() -> str:
|
||||
return "liuhaotian/LLaVA-7b-delta-v0"
|
27
extensions/multimodal/pipelines/llava/pipelines.py
Normal file
27
extensions/multimodal/pipelines/llava/pipelines.py
Normal file
@ -0,0 +1,27 @@
|
||||
from typing import Optional
|
||||
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
|
||||
available_pipelines = ['llava-7b', 'llava-13b']
|
||||
|
||||
|
||||
def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
||||
if name == 'llava-7b':
|
||||
from .llava import LLaVA_v0_7B_Pipeline
|
||||
return LLaVA_v0_7B_Pipeline(params)
|
||||
if name == 'llava-13b':
|
||||
from .llava import LLaVA_v0_13B_Pipeline
|
||||
return LLaVA_v0_13B_Pipeline(params)
|
||||
return None
|
||||
|
||||
|
||||
def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
||||
if 'llava' not in model_name.lower():
|
||||
return None
|
||||
if '7b' in model_name.lower():
|
||||
from .llava import LLaVA_v0_7B_Pipeline
|
||||
return LLaVA_v0_7B_Pipeline(params)
|
||||
if '13b' in model_name.lower():
|
||||
from .llava import LLaVA_v0_13B_Pipeline
|
||||
return LLaVA_v0_13B_Pipeline(params)
|
||||
return None
|
103
extensions/multimodal/script.py
Normal file
103
extensions/multimodal/script.py
Normal file
@ -0,0 +1,103 @@
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
|
||||
from modules import shared
|
||||
|
||||
params = {
|
||||
"add_all_images_to_prompt": False,
|
||||
# device to run vision encoder on
|
||||
"vision_device": None,
|
||||
# bits to load vision encoder in, either 16 or 32
|
||||
"vision_bits": 32,
|
||||
# device to run multimodal projector on
|
||||
"projector_device": None,
|
||||
# multimodal projector bits, either 32 or 16
|
||||
"projector_bits": 32
|
||||
}
|
||||
|
||||
|
||||
# If 'state' is True, will hijack the next chat generation
|
||||
input_hijack = {
|
||||
'state': False,
|
||||
'value': ["", ""]
|
||||
}
|
||||
|
||||
|
||||
# initialized in ui, so that params are loaded from settings
|
||||
multimodal_embedder: MultimodalEmbedder = None
|
||||
|
||||
|
||||
def add_chat_picture(picture, text, visible_text):
|
||||
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
|
||||
max_hw, min_hw = max(picture.size), min(picture.size)
|
||||
aspect_ratio = max_hw / min_hw
|
||||
shortest_edge = int(max(300 / aspect_ratio, 224))
|
||||
longest_edge = int(shortest_edge * aspect_ratio)
|
||||
w = shortest_edge if picture.width < picture.height else longest_edge
|
||||
h = shortest_edge if picture.width >= picture.height else longest_edge
|
||||
picture = picture.resize((w,h))
|
||||
|
||||
buffer = BytesIO()
|
||||
picture.save(buffer, format="JPEG")
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
image = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||
|
||||
|
||||
if '<image>' in text:
|
||||
text = text.replace('<image>', image)
|
||||
else:
|
||||
text = text + '\n' + image
|
||||
|
||||
if visible_text == '' or visible_text is None:
|
||||
visible_text = text
|
||||
elif '<image>' in visible_text:
|
||||
visible_text = visible_text.replace('<image>', image)
|
||||
else:
|
||||
visible_text = visible_text + '\n' + image
|
||||
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def custom_tokenized_length(prompt):
|
||||
return multimodal_embedder.len_in_tokens(prompt)
|
||||
|
||||
|
||||
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
||||
global params
|
||||
start_ts = time.time()
|
||||
image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)
|
||||
|
||||
if image_match is None:
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
|
||||
logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
|
||||
return (prompt,
|
||||
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
|
||||
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
|
||||
|
||||
|
||||
def ui():
|
||||
global multimodal_embedder
|
||||
multimodal_embedder = MultimodalEmbedder(params)
|
||||
with gr.Column():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
# The models don't seem to deal well with multiple images
|
||||
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
|
||||
# Prepare the input hijack
|
||||
picture_select.upload(
|
||||
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
|
||||
[picture_select],
|
||||
None
|
||||
)
|
||||
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
|
||||
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
|
||||
shared.gradio['Generate'].click(lambda: None, None, picture_select)
|
||||
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
|
@ -14,7 +14,7 @@ from PIL import Image
|
||||
import modules.shared as shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import chat_html_wrapper, make_thumbnail
|
||||
from modules.text_generation import (encode, generate_reply,
|
||||
from modules.text_generation import (generate_reply, get_encoded_length,
|
||||
get_max_prompt_length)
|
||||
|
||||
|
||||
@ -67,7 +67,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
|
||||
# Building the prompt
|
||||
i = len(history) - 1
|
||||
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||
while i >= 0 and get_encoded_length(''.join(rows)) < max_length:
|
||||
if _continue and i == len(history) - 1:
|
||||
rows.insert(1, bot_turn_stripped + history[i][1].strip())
|
||||
else:
|
||||
@ -90,7 +90,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
|
||||
|
||||
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
|
||||
while len(rows) > min_rows and get_encoded_length(''.join(rows)) >= max_length:
|
||||
rows.pop(1)
|
||||
|
||||
prompt = ''.join(rows)
|
||||
|
@ -7,6 +7,7 @@ import gradio as gr
|
||||
import extensions
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
state = {}
|
||||
available_extensions = []
|
||||
setup_called = set()
|
||||
@ -73,15 +74,12 @@ def _apply_input_hijack(text, visible_text):
|
||||
return text, visible_text
|
||||
|
||||
|
||||
# custom_generate_chat_prompt handling
|
||||
# custom_generate_chat_prompt handling - currently only the first one will work
|
||||
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
|
||||
custom_generate_chat_prompt = None
|
||||
for extension, _ in iterator():
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
|
||||
|
||||
if custom_generate_chat_prompt is not None:
|
||||
return custom_generate_chat_prompt(text, state, **kwargs)
|
||||
if hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
return custom_generate_chat_prompt(text, state, **kwargs)
|
||||
|
||||
return None
|
||||
|
||||
@ -95,16 +93,26 @@ def _apply_state_modifier_extensions(state):
|
||||
return state
|
||||
|
||||
|
||||
# Extension functions that override the default tokenizer output
|
||||
# Extension functions that override the default tokenizer output - currently only the first one will work
|
||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, function_name):
|
||||
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
|
||||
return getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
|
||||
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
|
||||
# Custom generate reply handling
|
||||
# Get prompt length in tokens after applying extension functions which override the default tokenizer output
|
||||
# currently only the first one will work
|
||||
def _apply_custom_tokenized_length(prompt):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'custom_tokenized_length'):
|
||||
return getattr(extension, 'custom_tokenized_length')(prompt)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Custom generate reply handling - currently only the first one will work
|
||||
def _apply_custom_generate_reply():
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'custom_generate_reply'):
|
||||
@ -121,7 +129,8 @@ EXTENSION_MAP = {
|
||||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||
"input_hijack": _apply_input_hijack,
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
|
||||
"custom_generate_reply": _apply_custom_generate_reply
|
||||
"custom_generate_reply": _apply_custom_generate_reply,
|
||||
"tokenized_length": _apply_custom_tokenized_length
|
||||
}
|
||||
|
||||
|
||||
|
@ -166,6 +166,8 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
|
||||
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
||||
|
||||
# Multimodal
|
||||
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
|
||||
|
||||
args = parser.parse_args()
|
||||
args_defaults = parser.parse_args([])
|
||||
@ -183,12 +185,21 @@ if args.trust_remote_code:
|
||||
if args.share:
|
||||
logging.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
|
||||
|
||||
|
||||
def add_extension(name):
|
||||
if args.extensions is None:
|
||||
args.extensions = [name]
|
||||
elif 'api' not in args.extensions:
|
||||
args.extensions.append(name)
|
||||
|
||||
|
||||
# Activating the API extension
|
||||
if args.api or args.public_api:
|
||||
if args.extensions is None:
|
||||
args.extensions = ['api']
|
||||
elif 'api' not in args.extensions:
|
||||
args.extensions.append('api')
|
||||
add_extension('api')
|
||||
|
||||
# Activating the multimodal extension
|
||||
if args.multimodal_pipeline is not None:
|
||||
add_extension('multimodal')
|
||||
|
||||
|
||||
def is_chat():
|
||||
|
@ -59,6 +59,14 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
||||
return input_ids.cuda()
|
||||
|
||||
|
||||
def get_encoded_length(prompt):
|
||||
length_after_extensions = apply_extensions('tokenized_length', prompt)
|
||||
if length_after_extensions is not None:
|
||||
return length_after_extensions
|
||||
|
||||
return len(encode(prompt)[0])
|
||||
|
||||
|
||||
def decode(output_ids, skip_special_tokens=True):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||
|
||||
|
@ -48,7 +48,7 @@ from modules import chat, shared, training, ui, utils
|
||||
from modules.html_generator import chat_html_wrapper
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, load_soft_prompt, unload_model
|
||||
from modules.text_generation import encode, generate_reply, stop_everything_event
|
||||
from modules.text_generation import generate_reply, get_encoded_length, stop_everything_event
|
||||
|
||||
|
||||
def load_model_wrapper(selected_model, autoload=False):
|
||||
@ -140,7 +140,7 @@ def load_prompt(fname):
|
||||
|
||||
|
||||
def count_tokens(text):
|
||||
tokens = len(encode(text)[0])
|
||||
tokens = get_encoded_length(text)
|
||||
return f'{tokens} tokens in the input.'
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user