2023-05-09 22:49:39 -03:00

104 lines
3.6 KiB
Python

import base64
import logging
import re
import time
from functools import partial
from io import BytesIO
import gradio as gr
import torch
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared
params = {
"add_all_images_to_prompt": False,
# device to run vision encoder on
"vision_device": None,
# bits to load vision encoder in, either 16 or 32
"vision_bits": 32,
# device to run multimodal projector on
"projector_device": None,
# multimodal projector bits, either 32 or 16
"projector_bits": 32
}
# If 'state' is True, will hijack the next chat generation
input_hijack = {
'state': False,
'value': ["", ""]
}
# initialized in ui, so that params are loaded from settings
multimodal_embedder: MultimodalEmbedder = None
def add_chat_picture(picture, text, visible_text):
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
max_hw, min_hw = max(picture.size), min(picture.size)
aspect_ratio = max_hw / min_hw
shortest_edge = int(max(300 / aspect_ratio, 224))
longest_edge = int(shortest_edge * aspect_ratio)
w = shortest_edge if picture.width < picture.height else longest_edge
h = shortest_edge if picture.width >= picture.height else longest_edge
picture = picture.resize((w, h))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
image = f'<img src="data:image/jpeg;base64,{img_str}">'
if '<image>' in text:
text = text.replace('<image>', image)
else:
text = text + '\n' + image
if visible_text == '' or visible_text is None:
visible_text = text
elif '<image>' in visible_text:
visible_text = visible_text.replace('<image>', image)
else:
visible_text = visible_text + '\n' + image
return text, visible_text
def custom_tokenized_length(prompt):
return multimodal_embedder.len_in_tokens(prompt)
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
global params
start_ts = time.time()
image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)
if image_match is None:
return prompt, input_ids, input_embeds
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
return (prompt,
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
def ui():
global multimodal_embedder
multimodal_embedder = MultimodalEmbedder(params)
with gr.Column():
picture_select = gr.Image(label='Send a picture', type='pil')
# The models don't seem to deal well with multiple images
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
# Prepare the input hijack
picture_select.upload(
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
[picture_select],
None
)
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
shared.gradio['Generate'].click(lambda: None, None, picture_select)
shared.gradio['textbox'].submit(lambda: None, None, picture_select)