mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-30 22:20:14 +01:00
5389fce8e1
Reworked remove_surrounded_chars() to use a regular expression ( https://regexr.com/7alb5 ) instead of repeated string concatenations for elevenlab_tts, silero_tts, sd_api_pictures. This should be both faster and more robust in handling asterisks. Reduced the memory footprint of send_pictures and sd_api_pictures by scaling chat images down to at most 300 pixels on the longer side. (The user already has the original in the case of a sent picture, and there is an option to save the SD generation.) This should fix the history growing annoyingly large when multiple pictures are present.
58 lines
2.4 KiB
Python
58 lines
2.4 KiB
Python
import base64
|
|
from io import BytesIO
|
|
|
|
import gradio as gr
|
|
import torch
|
|
from transformers import BlipForConditionalGeneration, BlipProcessor
|
|
from PIL import Image
|
|
|
|
import modules.chat as chat
|
|
import modules.shared as shared
|
|
|
|
# Hijack flag for the next chat generation: while 'state' is True, the chat
# pipeline uses 'value' (a [text, visible_text] pair) as the user input
# instead of whatever was typed into the textbox.
input_hijack = {
    'state': False,
    'value': ["", ""],
}
|
|
|
|
# BLIP image-captioning model, loaded once at import time and pinned to CPU
# in float32 so the extension works without a GPU. NOTE(review): this
# download/load happens on import — presumably acceptable for an extension
# module, but it does block startup; confirm against the loader's expectations.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
|
def caption_image(raw_image):
    """Return a natural-language caption for *raw_image*.

    The image is converted to RGB, run through the module-level BLIP
    processor/model on CPU, and the generated ids are decoded back to text
    (generation is capped at 100 new tokens).
    """
    rgb_image = raw_image.convert('RGB')
    model_inputs = processor(rgb_image, return_tensors="pt").to("cpu", torch.float32)
    token_ids = model.generate(**model_inputs, max_new_tokens=100)
    caption = processor.decode(token_ids[0], skip_special_tokens=True)
    return caption
|
def generate_chat_picture(picture, name1, name2):
    """Build the (text, visible_text) pair for a picture sent in chat.

    Parameters
    ----------
    picture : PIL.Image.Image
        The uploaded image.
    name1, name2 : str
        Sender and recipient names interpolated into the caption sentence.

    Returns
    -------
    tuple[str, str]
        ``text`` is what the model sees (an auto-captioned description);
        ``visible_text`` is an ``<img>`` tag with the downscaled image
        embedded as base64 for the visible chat log.
    """
    text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'

    # Lower the resolution of sent images for the chat, otherwise the log size
    # gets out of control quickly with all the base64 values in visible history.
    width, height = picture.size
    if width > 300:
        height = int(height * (300 / width))
        width = 300
    elif height > 300:
        width = int(width * (300 / height))
        height = 300
    picture = picture.resize((width, height), Image.LANCZOS)

    buffer = BytesIO()
    # JPEG cannot store an alpha channel or palette: convert to RGB first so
    # RGBA/P uploads (e.g. transparent PNGs) don't raise OSError on save.
    picture.convert('RGB').save(buffer, format="JPEG")
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
    return text, visible_text
def ui():
    """Create the extension's Gradio widgets and wire up the upload events."""
    picture_select = gr.Image(label='Send a picture', type='pil')

    # Select the generation wrapper for the active chat mode directly instead
    # of eval()-ing a dotted-name string — same function, no dynamic eval.
    generate_fn = chat.cai_chatbot_wrapper if shared.args.cai_chat else chat.chatbot_wrapper

    # Prepare the hijack with custom inputs
    picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)

    # Call the generation function
    picture_select.upload(generate_fn, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)

    # Clear the picture from the upload field
    picture_select.upload(lambda: None, [], [picture_select], show_progress=False)