From 7a527a55818c591af899809bac0a976457d07e55 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 25 Feb 2023 00:23:51 -0300 Subject: [PATCH] Move "send picture" into an extension I am not proud of how I did it for now. --- modules/chat.py | 38 ++++++++++++++++---------------------- modules/extensions.py | 2 -- modules/shared.py | 4 +++- server.py | 39 +++++++++++++++------------------------ 4 files changed, 34 insertions(+), 49 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 0e6318f6..140d2ed8 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -4,19 +4,16 @@ import io import json import re from datetime import datetime -from io import BytesIO from pathlib import Path from PIL import Image import modules.shared as shared +import modules.extensions as extensions_module from modules.extensions import apply_extensions from modules.html_generator import generate_chat_html from modules.text_generation import encode, generate_reply, get_max_prompt_length -if shared.args.picture and (shared.args.cai_chat or shared.args.chat): - import modules.bot_picture as bot_picture - # This gets the new line characters right. def clean_chat_message(text): text = text.replace('\n', '\n\n') @@ -84,18 +81,10 @@ def extract_message_from_reply(question, reply, current, other, check, extension return reply, next_character_found, substring_found -def generate_chat_picture(picture, name1, name2): - text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*' - buffer = BytesIO() - picture.save(buffer, format="JPEG") - img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') - visible_text = f'' - return text, visible_text - def stop_everything_event(): shared.stop_everything = True -def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): +def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size): shared.stop_everything = False just_started = True eos_token = '\n' if check else None @@ -103,13 +92,18 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical if 'pygmalion' in shared.model_name.lower(): name1 = "You" - if shared.args.picture and picture is not None: - text, visible_text = generate_chat_picture(picture, name1, name2) - else: + # Hijacking the input using an extension + visible_text = None + for extension, _ in extensions_module.iterator(): + if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True: + text, visible_text = extension.input_hijack['value'] + extension.input_hijack['state'] = False + + if visible_text is None: visible_text = text if shared.args.chat: visible_text = visible_text.replace('\n', '
') - text = apply_extensions(text, "input") + text = apply_extensions(text, "input") prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) # Generate @@ -138,7 +132,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical break yield shared.history['visible'] -def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): +def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size): eos_token = '\n' if check else None if 'pygmalion' in shared.model_name.lower(): @@ -154,11 +148,11 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ break yield reply -def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): - for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): +def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size): + for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size): yield generate_chat_html(_history, name1, name2, shared.character) -def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): +def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size): if shared.character != 'None' and len(shared.history['visible']) == 1: if shared.args.cai_chat: yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) @@ -168,7 +162,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() - for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): + for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) diff --git a/modules/extensions.py b/modules/extensions.py index 6af7c58a..c0da496a 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,5 +1,3 @@ -import gradio as gr - import extensions import modules.shared as shared diff --git a/modules/shared.py b/modules/shared.py index d4ffc19d..7b87d285 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,6 +14,9 @@ stop_everything = False # UI elements (buttons, sliders, HTML, etc) gradio = {} +# Generation input parameters +input_params = [] + settings = { 'max_new_tokens': 200, 'max_new_tokens_min': 1, @@ -40,7 +43,6 @@ parser.add_argument('--model', type=str, help='Name of the model to load by defa parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.') parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.') -parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.') parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') diff --git a/server.py b/server.py index 5db441b6..32a1d2e4 100644 --- a/server.py +++ b/server.py @@ -185,7 +185,6 @@ else: if shared.args.chat or shared.args.cai_chat: with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as shared.gradio['interface']: - shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) if shared.args.cai_chat: shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) else: @@ -202,10 +201,6 @@ if shared.args.chat or shared.args.cai_chat: with gr.Row(): shared.gradio['Send last reply to input'] = gr.Button('Send last reply to input') shared.gradio['Replace last reply'] = gr.Button('Replace last reply') - if shared.args.picture: - with gr.Row(): - shared.gradio['picture_select'] = gr.Image(label='Send a picture', type='pil') - with gr.Tab('Chat settings'): shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') @@ -246,23 +241,19 @@ if shared.args.chat or shared.args.cai_chat: shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) with gr.Column(): shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) - create_settings_menus() + + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider']] if shared.args.extensions is not None: with gr.Tab('Extensions'): extensions_module.create_extensions_block() - input_params = [shared.gradio[i] for i in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider']] - if shared.args.picture: - input_params.append(shared.gradio['picture_select']) function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' - gen_events.append(shared.gradio['Generate'].click(eval(function_call), input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen')) - gen_events.append(shared.gradio['textbox'].submit(eval(function_call), input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) - if shared.args.picture: - shared.gradio['picture_select'].upload(eval(function_call), input_params, shared.gradio['display'], show_progress=shared.args.no_stream) - gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) - gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events) shared.gradio['Send last reply to input'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) @@ -284,13 +275,13 @@ if shared.args.chat or shared.args.cai_chat: shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], []) - if shared.args.picture: - shared.gradio['picture_select'].upload(lambda : None, [], [shared.gradio['picture_select']], show_progress=False) reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) + + shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) elif shared.args.notebook: @@ -311,10 +302,10 @@ elif shared.args.notebook: if shared.args.extensions is not None: extensions_module.create_extensions_block() - input_params = [shared.gradio[k] for k in ('textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping')] + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] - gen_events.append(shared.gradio['Generate'].click(generate_reply, input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) - gen_events.append(shared.gradio['textbox'].submit(generate_reply, input_params, output_params, show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(None, None, None, cancels=gen_events) else: @@ -343,11 +334,11 @@ else: with gr.Tab('HTML'): shared.gradio['html'] = gr.HTML() - input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] - gen_events.append(shared.gradio['Generate'].click(generate_reply, input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) - gen_events.append(shared.gradio['textbox'].submit(generate_reply, input_params, output_params, show_progress=shared.args.no_stream)) - gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + input_params[1:], output_params, show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(None, None, None, cancels=gen_events) shared.gradio['interface'].queue()