104 lines
3.6 KiB
Python
Raw Normal View History

import base64
import logging
import re
import time
from functools import partial
from io import BytesIO
import gradio as gr
import torch
2023-05-09 22:49:39 -03:00
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared
params = {
"add_all_images_to_prompt": False,
# device to run vision encoder on
"vision_device": None,
# bits to load vision encoder in, either 16 or 32
"vision_bits": 32,
# device to run multimodal projector on
"projector_device": None,
# multimodal projector bits, either 32 or 16
"projector_bits": 32
}
# If 'state' is True, will hijack the next chat generation
input_hijack = {
'state': False,
'value': ["", ""]
}
# initialized in ui, so that params are loaded from settings
multimodal_embedder: MultimodalEmbedder = None
def add_chat_picture(picture, text, visible_text):
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
max_hw, min_hw = max(picture.size), min(picture.size)
aspect_ratio = max_hw / min_hw
shortest_edge = int(max(300 / aspect_ratio, 224))
longest_edge = int(shortest_edge * aspect_ratio)
w = shortest_edge if picture.width < picture.height else longest_edge
h = shortest_edge if picture.width >= picture.height else longest_edge
2023-05-09 20:20:35 -03:00
picture = picture.resize((w, h))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
image = f'<img src="data:image/jpeg;base64,{img_str}">'
if '<image>' in text:
text = text.replace('<image>', image)
else:
text = text + '\n' + image
if visible_text == '' or visible_text is None:
visible_text = text
elif '<image>' in visible_text:
visible_text = visible_text.replace('<image>', image)
else:
visible_text = visible_text + '\n' + image
return text, visible_text
def custom_tokenized_length(prompt):
return multimodal_embedder.len_in_tokens(prompt)
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
global params
start_ts = time.time()
image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)
if image_match is None:
return prompt, input_ids, input_embeds
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
return (prompt,
2023-05-09 20:20:35 -03:00
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
def ui():
global multimodal_embedder
multimodal_embedder = MultimodalEmbedder(params)
with gr.Column():
picture_select = gr.Image(label='Send a picture', type='pil')
# The models don't seem to deal well with multiple images
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
# Prepare the input hijack
picture_select.upload(
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
[picture_select],
None
)
2023-05-09 20:20:35 -03:00
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
shared.gradio['Generate'].click(lambda: None, None, picture_select)
shared.gradio['textbox'].submit(lambda: None, None, picture_select)