Mirror of https://github.com/oobabooga/text-generation-webui.git
Synced 2024-11-22 08:07:56 +01:00
Move "send picture" into an extension
I am not proud of how I did it for now.
This commit is contained in:
parent e51ece21c0
commit 7a527a5581
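For context: this commit replaces the hard-wired `--picture` code path with a generic "input hijack" hook. Any extension that exposes an `input_hijack` dict with `state` and `value` keys can substitute the user's next chat input, as the new loop in modules/chat.py below shows. Here is a minimal sketch of the extension side, reconstructed from the consumer code in this diff; the module path, the `caption_image` stub, and the Gradio wiring are illustrative assumptions, not part of this commit:

```python
# extensions/send_pictures/script.py -- hypothetical sketch of the extension
# side of the new hook. Only the input_hijack contract (the 'state' and
# 'value' keys) is taken from this diff; everything else is assumed.
import base64
from io import BytesIO

import gradio as gr

# chat.py checks hasattr(extension, 'input_hijack'), reads 'value' as a
# (text, visible_text) pair when 'state' is truthy, then resets 'state'.
input_hijack = {
    'state': False,
    'value': ["", ""],
}

def caption_image(picture):
    # Stub: the real extension would caption the image (e.g. with BLIP).
    return "an image"

def generate_chat_picture(picture, name1, name2):
    # Body lifted from the function removed from modules/chat.py below.
    text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
    buffer = BytesIO()
    picture.save(buffer, format="JPEG")
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
    return text, visible_text

def ui():
    # Assumed UI hook: arm the hijack whenever a picture is uploaded.
    picture_select = gr.Image(label='Send a picture', type='pil')

    def store_hijack(picture):
        input_hijack['value'] = generate_chat_picture(picture, 'You', 'Bot')
        input_hijack['state'] = True

    picture_select.upload(store_hijack, [picture_select], None)
```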
modules/chat.py
@@ -4,19 +4,16 @@ import io
 import json
 import re
 from datetime import datetime
-from io import BytesIO
 from pathlib import Path

 from PIL import Image

 import modules.shared as shared
+import modules.extensions as extensions_module
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_chat_html
 from modules.text_generation import encode, generate_reply, get_max_prompt_length

-if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
-    import modules.bot_picture as bot_picture
-
 # This gets the new line characters right.
 def clean_chat_message(text):
     text = text.replace('\n', '\n\n')
@@ -84,18 +81,10 @@ def extract_message_from_reply(question, reply, current, other, check, extension

     return reply, next_character_found, substring_found

-def generate_chat_picture(picture, name1, name2):
-    text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*'
-    buffer = BytesIO()
-    picture.save(buffer, format="JPEG")
-    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
-    visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
-    return text, visible_text
-
 def stop_everything_event():
     shared.stop_everything = True

-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
     shared.stop_everything = False
     just_started = True
     eos_token = '\n' if check else None
@@ -103,13 +92,18 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"

-    if shared.args.picture and picture is not None:
-        text, visible_text = generate_chat_picture(picture, name1, name2)
-    else:
+    # Hijacking the input using an extension
+    visible_text = None
+    for extension, _ in extensions_module.iterator():
+        if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
+            text, visible_text = extension.input_hijack['value']
+            extension.input_hijack['state'] = False
+
+    if visible_text is None:
         visible_text = text
         if shared.args.chat:
             visible_text = visible_text.replace('\n', '<br>')
-        text = apply_extensions(text, "input")
+    text = apply_extensions(text, "input")
     prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)

     # Generate
@@ -138,7 +132,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
            break
    yield shared.history['visible']

-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
    eos_token = '\n' if check else None

    if 'pygmalion' in shared.model_name.lower():
@@ -154,11 +148,11 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
            break
    yield reply

-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
-    for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
+def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
+    for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
        yield generate_chat_html(_history, name1, name2, shared.character)

-def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
+def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
    if shared.character != 'None' and len(shared.history['visible']) == 1:
        if shared.args.cai_chat:
            yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
@@ -168,7 +162,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
        last_visible = shared.history['visible'].pop()
        last_internal = shared.history['internal'].pop()

-        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
+        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
            if shared.args.cai_chat:
                shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
                yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
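The hijack consumption added above is one-shot: an armed hijack replaces exactly one input, then disarms itself. A standalone illustration of that behavior (toy code, not from the repository):

```python
# Toy re-statement of the consumer loop added to chatbot_wrapper above,
# runnable without the web UI. FakeExtension stands in for an extension
# module yielded by extensions_module.iterator().
class FakeExtension:
    input_hijack = {'state': True, 'value': ('*sends a picture*', '<img ...>')}

def consume_hijack(text, extension_modules):
    visible_text = None
    for extension in extension_modules:
        if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
            text, visible_text = extension.input_hijack['value']
            extension.input_hijack['state'] = False  # disarm: one-shot
    if visible_text is None:
        visible_text = text
    return text, visible_text

print(consume_hijack('hello', [FakeExtension]))  # hijacked input
print(consume_hijack('hello', [FakeExtension]))  # passes through unchanged
```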
modules/extensions.py
@@ -1,5 +1,3 @@
-import gradio as gr
-
 import extensions
 import modules.shared as shared

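The new loop in modules/chat.py unpacks `for extension, _ in extensions_module.iterator()`, so `iterator()` must yield 2-tuples whose first element is the extension's script module. Its implementation is not part of this diff; a sketch of the assumed contract:

```python
# Assumed shape of modules/extensions.py's iterator(); not shown in this
# diff. The only property chat.py relies on is that it yields
# (script_module, name) pairs for each enabled extension.
import extensions  # package containing one sub-package per extension

state = {}  # hypothetical registry: name -> [enabled, load_order]

def iterator():
    for name in sorted(state, key=lambda name: state[name][1]):
        if state[name][0]:  # only enabled extensions participate
            yield getattr(extensions, name).script, name
```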
modules/shared.py
@@ -14,6 +14,9 @@ stop_everything = False
 # UI elements (buttons, sliders, HTML, etc)
 gradio = {}

+# Generation input parameters
+input_params = []
+
 settings = {
     'max_new_tokens': 200,
     'max_new_tokens_min': 1,
@@ -40,7 +43,6 @@ parser.add_argument('--model', type=str, help='Name of the model to load by defa
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
 parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
-parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.')
 parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
 parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
server.py
@@ -185,7 +185,6 @@ else:

 if shared.args.chat or shared.args.cai_chat:
     with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as shared.gradio['interface']:
-        shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
         if shared.args.cai_chat:
             shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
         else:
@@ -202,10 +201,6 @@ if shared.args.chat or shared.args.cai_chat:
         with gr.Row():
             shared.gradio['Send last reply to input'] = gr.Button('Send last reply to input')
             shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
-        if shared.args.picture:
-            with gr.Row():
-                shared.gradio['picture_select'] = gr.Image(label='Send a picture', type='pil')
-
         with gr.Tab('Chat settings'):
             shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
             shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
@@ -246,23 +241,19 @@ if shared.args.chat or shared.args.cai_chat:
                     shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
                 with gr.Column():
                     shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])

         create_settings_menus()
+
+        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider']]
         if shared.args.extensions is not None:
             with gr.Tab('Extensions'):
                 extensions_module.create_extensions_block()
-
-        input_params = [shared.gradio[i] for i in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider']]
-        if shared.args.picture:
-            input_params.append(shared.gradio['picture_select'])
         function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'

-        gen_events.append(shared.gradio['Generate'].click(eval(function_call), input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
-        gen_events.append(shared.gradio['textbox'].submit(eval(function_call), input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
-        if shared.args.picture:
-            shared.gradio['picture_select'].upload(eval(function_call), input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
-        gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
-        gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
+        gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
+        gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+        gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+        gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
         shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)

         shared.gradio['Send last reply to input'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
@@ -284,13 +275,13 @@ if shared.args.chat or shared.args.cai_chat:
         shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
         shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
         shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
-        if shared.args.picture:
-            shared.gradio['picture_select'].upload(lambda : None, [], [shared.gradio['picture_select']], show_progress=False)

         reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
         reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
         shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
         shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])

+        shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
+        shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)

 elif shared.args.notebook:
@@ -311,10 +302,10 @@ elif shared.args.notebook:
     if shared.args.extensions is not None:
         extensions_module.create_extensions_block()

-    input_params = [shared.gradio[k] for k in ('textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping')]
+    shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
     output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
-    gen_events.append(shared.gradio['Generate'].click(generate_reply, input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
-    gen_events.append(shared.gradio['textbox'].submit(generate_reply, input_params, output_params, show_progress=shared.args.no_stream))
+    gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+    gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
     shared.gradio['Stop'].click(None, None, None, cancels=gen_events)

 else:
@@ -343,11 +334,11 @@ else:
         with gr.Tab('HTML'):
             shared.gradio['html'] = gr.HTML()

-    input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+    shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
     output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
-    gen_events.append(shared.gradio['Generate'].click(generate_reply, input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
-    gen_events.append(shared.gradio['textbox'].submit(generate_reply, input_params, output_params, show_progress=shared.args.no_stream))
-    gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + input_params[1:], output_params, show_progress=shared.args.no_stream))
+    gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+    gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
+    gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
     shared.gradio['Stop'].click(None, None, None, cancels=gen_events)

 shared.gradio['interface'].queue()
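Finally, note the pattern running through the server.py changes: the three UI modes previously built their own local `input_params` lists, and they now all assign to `shared.input_params` (declared in modules/shared.py above), presumably so code outside server.py can see which Gradio components feed generation. A toy illustration of why sharing one list object matters (standalone, not repository code):

```python
# Toy illustration: when every event binding references the same list object,
# registering an extra input later is visible to all of them at once.
shared_input_params = ['textbox', 'max_new_tokens', 'temperature']

bindings = {
    event: shared_input_params   # same list object, not a copy
    for event in ('generate.click', 'textbox.submit', 'regenerate.click')
}

shared_input_params.append('top_p')  # all three bindings now include it
assert all(params[-1] == 'top_p' for params in bindings.values())
```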