From 7a527a55818c591af899809bac0a976457d07e55 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 25 Feb 2023 00:23:51 -0300
Subject: [PATCH] Move "send picture" into an extension
I am not proud of how I did it for now.
---
modules/chat.py | 38 ++++++++++++++++----------------------
modules/extensions.py | 2 --
modules/shared.py | 4 +++-
server.py | 39 +++++++++++++++------------------------
4 files changed, 34 insertions(+), 49 deletions(-)
diff --git a/modules/chat.py b/modules/chat.py
index 0e6318f6..140d2ed8 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -4,19 +4,16 @@ import io
import json
import re
from datetime import datetime
-from io import BytesIO
from pathlib import Path
from PIL import Image
import modules.shared as shared
+import modules.extensions as extensions_module
from modules.extensions import apply_extensions
from modules.html_generator import generate_chat_html
from modules.text_generation import encode, generate_reply, get_max_prompt_length
-if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
- import modules.bot_picture as bot_picture
-
# This gets the new line characters right.
def clean_chat_message(text):
text = text.replace('\n', '\n\n')
@@ -84,18 +81,10 @@ def extract_message_from_reply(question, reply, current, other, check, extension
return reply, next_character_found, substring_found
-def generate_chat_picture(picture, name1, name2):
- text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*'
- buffer = BytesIO()
- picture.save(buffer, format="JPEG")
- img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
- visible_text = f''
- return text, visible_text
-
def stop_everything_event():
shared.stop_everything = True
-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
shared.stop_everything = False
just_started = True
eos_token = '\n' if check else None
@@ -103,13 +92,18 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
- if shared.args.picture and picture is not None:
- text, visible_text = generate_chat_picture(picture, name1, name2)
- else:
+ # Hijacking the input using an extension
+ visible_text = None
+ for extension, _ in extensions_module.iterator():
+ if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
+ text, visible_text = extension.input_hijack['value']
+ extension.input_hijack['state'] = False
+
+ if visible_text is None:
visible_text = text
if shared.args.chat:
visible_text = visible_text.replace('\n', '
')
- text = apply_extensions(text, "input")
+ text = apply_extensions(text, "input")
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
# Generate
@@ -138,7 +132,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
break
yield shared.history['visible']
-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
eos_token = '\n' if check else None
if 'pygmalion' in shared.model_name.lower():
@@ -154,11 +148,11 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
break
yield reply
-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
- for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
+def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
+ for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
yield generate_chat_html(_history, name1, name2, shared.character)
-def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
if shared.character != 'None' and len(shared.history['visible']) == 1:
if shared.args.cai_chat:
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
@@ -168,7 +162,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop()
- for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
+ for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
if shared.args.cai_chat:
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
diff --git a/modules/extensions.py b/modules/extensions.py
index 6af7c58a..c0da496a 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,5 +1,3 @@
-import gradio as gr
-
import extensions
import modules.shared as shared
diff --git a/modules/shared.py b/modules/shared.py
index d4ffc19d..7b87d285 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,6 +14,9 @@ stop_everything = False
# UI elements (buttons, sliders, HTML, etc)
gradio = {}
+# Generation input parameters
+input_params = []
+
settings = {
'max_new_tokens': 200,
'max_new_tokens_min': 1,
@@ -40,7 +43,6 @@ parser.add_argument('--model', type=str, help='Name of the model to load by defa
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
-parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.')
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
diff --git a/server.py b/server.py
index 5db441b6..32a1d2e4 100644
--- a/server.py
+++ b/server.py
@@ -185,7 +185,6 @@ else:
if shared.args.chat or shared.args.cai_chat:
with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as shared.gradio['interface']:
- shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
if shared.args.cai_chat:
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
else:
@@ -202,10 +201,6 @@ if shared.args.chat or shared.args.cai_chat:
with gr.Row():
shared.gradio['Send last reply to input'] = gr.Button('Send last reply to input')
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
- if shared.args.picture:
- with gr.Row():
- shared.gradio['picture_select'] = gr.Image(label='Send a picture', type='pil')
-
with gr.Tab('Chat settings'):
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
@@ -246,23 +241,19 @@ if shared.args.chat or shared.args.cai_chat:
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
with gr.Column():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
-
create_settings_menus()
+
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider']]
if shared.args.extensions is not None:
with gr.Tab('Extensions'):
extensions_module.create_extensions_block()
- input_params = [shared.gradio[i] for i in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider']]
- if shared.args.picture:
- input_params.append(shared.gradio['picture_select'])
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
- gen_events.append(shared.gradio['Generate'].click(eval(function_call), input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
- gen_events.append(shared.gradio['textbox'].submit(eval(function_call), input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
- if shared.args.picture:
- shared.gradio['picture_select'].upload(eval(function_call), input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
- gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
- gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
+ gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
shared.gradio['Send last reply to input'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
@@ -284,13 +275,13 @@ if shared.args.chat or shared.args.cai_chat:
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
- if shared.args.picture:
- shared.gradio['picture_select'].upload(lambda : None, [], [shared.gradio['picture_select']], show_progress=False)
reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
+
+ shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook:
@@ -311,10 +302,10 @@ elif shared.args.notebook:
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
- input_params = [shared.gradio[k] for k in ('textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping')]
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
- gen_events.append(shared.gradio['Generate'].click(generate_reply, input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
- gen_events.append(shared.gradio['textbox'].submit(generate_reply, input_params, output_params, show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+ gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
else:
@@ -343,11 +334,11 @@ else:
with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML()
- input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
- gen_events.append(shared.gradio['Generate'].click(generate_reply, input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
- gen_events.append(shared.gradio['textbox'].submit(generate_reply, input_params, output_params, show_progress=shared.args.no_stream))
- gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + input_params[1:], output_params, show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+ gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
shared.gradio['interface'].queue()