diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index 14e9b641..034a4dd1 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -2,14 +2,13 @@ import base64 from io import BytesIO import gradio as gr +import modules.chat as chat +import modules.shared as shared import torch from transformers import BlipForConditionalGeneration, BlipProcessor -import modules.chat as chat -import modules.shared as shared - # If 'state' is True, will hijack the next chat generation with -# custom input text +# custom input text given by 'value' in the format [text, visible_text] input_hijack = { 'state': False, 'value': ["", ""] @@ -31,34 +30,16 @@ def generate_chat_picture(picture, name1, name2): visible_text = f'' return text, visible_text -def input_modifier(string): - """ - This function is applied to your text inputs before - they are fed into the model. - """ - - return string - -def output_modifier(string): - """ - This function is applied to the model outputs. - """ - - return string - -def bot_prefix_modifier(string): - """ - This function is only applied in chat mode. It modifies - the prefix text for the Bot and can be used to bias its - behavior. - """ - - return string - def ui(): picture_select = gr.Image(label='Send a picture', type='pil') function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' + + # Prepare the hijack with custom inputs picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None) + + # Call the generation function picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) + + # Clear the picture from the upload field picture_select.upload(lambda : None, [], [picture_select], show_progress=False) diff --git a/modules/chat.py b/modules/chat.py index 2cbc5c8e..3c8b4dbf 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -8,12 +8,13 @@ from pathlib import Path from PIL import Image -import modules.shared as shared import modules.extensions as extensions_module +import modules.shared as shared from modules.extensions import apply_extensions from modules.html_generator import generate_chat_html from modules.text_generation import encode, generate_reply, get_max_prompt_length + # This gets the new line characters right. def clean_chat_message(text): text = text.replace('\n', '\n\n')