"""Extension hooks for sending pictures in chat and embedding them in the prompt.

Pictures chosen in the UI are inlined into the chat message as base64 JPEG
``<img>`` tags by :func:`add_chat_picture`; :func:`tokenizer_modifier` later
hands prompts containing such tags to the ``MultimodalEmbedder``.
"""
import base64
import re
import time
from functools import partial
from io import BytesIO

import gradio as gr
import torch

from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared
from modules.logging_colors import logger

params = {
    "add_all_images_to_prompt": False,
    # device to run vision encoder on
    "vision_device": None,
    # bits to load vision encoder in, either 16 or 32
    "vision_bits": 32,
    # device to run multimodal projector on
    "projector_device": None,
    # multimodal projector bits, either 32 or 16
    "projector_bits": 32
}


# If 'state' is True, will hijack the next chat generation
input_hijack = {
    'state': False,
    'value': ["", ""]
}


# initialized in ui, so that params are loaded from settings
multimodal_embedder: MultimodalEmbedder = None


def _insert_image(message, image):
    """Return *message* with *image* replacing every '<image>' marker, or appended on a new line if there is none."""
    if '<image>' in message:
        return message.replace('<image>', image)

    return message + '\n' + image


def add_chat_picture(picture, text, visible_text):
    """Inline *picture* into a chat message as a base64 JPEG ``<img>`` tag.

    Args:
        picture: PIL image to attach (the UI component is created with
            ``type='pil'``). It is resized and JPEG-encoded; the caller's
            object is not modified.
        text: Message text. Every ``<image>`` marker is replaced by the tag;
            when there is no marker the tag is appended on a new line.
        visible_text: Displayed variant of the message. When empty or
            ``None`` it becomes a copy of the resulting ``text``; otherwise
            it receives the tag the same way ``text`` does.

    Returns:
        A ``(text, visible_text)`` tuple with the image tag inserted.
    """
    # JPEG has no alpha channel or palette support, so saving e.g. an RGBA or
    # P image raises OSError. Only modes that Pillow's JPEG writer is known to
    # accept are left untouched, so images that already worked are encoded
    # exactly as before.
    if picture.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        picture = picture.convert("RGB")

    # resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
    max_hw, min_hw = max(picture.size), min(picture.size)
    aspect_ratio = max_hw / min_hw
    shortest_edge = int(max(300 / aspect_ratio, 224))
    longest_edge = int(shortest_edge * aspect_ratio)
    if picture.width < picture.height:
        w, h = shortest_edge, longest_edge
    else:
        w, h = longest_edge, shortest_edge

    picture = picture.resize((w, h))

    buffer = BytesIO()
    picture.save(buffer, format="JPEG")
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    image = f'<img src="data:image/jpeg;base64,{img_str}">'

    text = _insert_image(text, image)

    if visible_text == '' or visible_text is None:
        visible_text = text
    else:
        visible_text = _insert_image(visible_text, image)

    return text, visible_text


def custom_tokenized_length(prompt):
    """Return the length of *prompt* in tokens as reported by the multimodal embedder."""
    return multimodal_embedder.len_in_tokens(prompt)


def tokenizer_modifier(state, prompt, input_ids, input_embeds):
    """Replace the tokenized prompt with multimodal embeddings when it contains an image.

    Prompts without an inline base64 JPEG ``<img>`` tag are returned untouched.
    Otherwise the prompt is passed to ``multimodal_embedder.forward`` and the
    resulting ids and embeddings get a leading dimension added and are moved
    to ``shared.model.device``.

    Returns:
        A ``(prompt, input_ids, input_embeds)`` tuple.
    """
    # perf_counter is monotonic, so the logged duration cannot be skewed by
    # system clock adjustments.
    start_ts = time.perf_counter()
    image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)

    if image_match is None:
        return prompt, input_ids, input_embeds

    prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
    logger.info(f'Embedded {total_embedded} image(s) in {time.perf_counter()-start_ts:.2f}s')
    return (prompt,
            input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
            input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))


def ui():
    """Create the embedder and build the picture upload controls with their event handlers."""
    global multimodal_embedder
    # Built here rather than at import time, so that params are loaded from settings.
    multimodal_embedder = MultimodalEmbedder(params)
    with gr.Column():
        picture_select = gr.Image(label='Send a picture', type='pil')
        # The models don't seem to deal well with multiple images
        single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
    # Prepare the input hijack
    picture_select.upload(
        lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
        [picture_select],
        None
    )
    picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
    single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
    # Both handlers output None into the picture component when a message is sent.
    shared.gradio['Generate'].click(lambda: None, None, picture_select)
    shared.gradio['textbox'].submit(lambda: None, None, picture_select)