Mirror of https://github.com/oobabooga/text-generation-webui.git
Fix send_pictures extension

commit 80f4eabb2a
parent 49ce866c99
extensions/send_pictures/script.py

@@ -6,6 +6,7 @@ import torch
 from transformers import BlipForConditionalGeneration, BlipProcessor
 
 from modules import chat, shared
+from modules.ui import gather_interface_values
 
 # If 'state' is True, will hijack the next chat generation with
 # custom input text given by 'value' in the format [text, visible_text]
@@ -42,7 +43,7 @@ def ui():
     picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
 
     # Call the generation function
-    picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
-
-    # Clear the picture from the upload field
-    picture_select.upload(lambda: None, [], [picture_select], show_progress=False)
+    picture_select.upload(
+        gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+        chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
+        lambda: None, None, picture_select, show_progress=False)
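The two hunks above are the send_pictures extension script. Previously the generation call and the field-clearing call were two independent .upload listeners; the fix merges them into a single chained listener that first snapshots the current UI settings into interface_state via the newly importable gather_interface_values, then runs chat.cai_chatbot_wrapper, then clears the upload widget. Below is a standalone sketch of that chaining pattern, not the repo's code: TRACKED, describe_picture and the two sliders are stand-ins, and it assumes a Gradio 3.x release that supports .upload on Image components and .then() chaining, as the extension itself does.

# Standalone sketch of the chained-event pattern the fix introduces. Not the
# repo's code: TRACKED, describe_picture and the sliders are stand-ins.
import gradio as gr

TRACKED = ['temperature', 'top_p']  # stand-in for shared.input_elements

def gather_interface_values(*args):
    # Snapshot the current widget values into one dict, like modules/ui.py does.
    return dict(zip(TRACKED, args))

def describe_picture(picture, state):
    # Stand-in for chat.cai_chatbot_wrapper; just report what arrived.
    return f"Got a {picture.size} image (temperature={state['temperature']}, top_p={state['top_p']})"

with gr.Blocks() as demo:
    temperature = gr.Slider(0.1, 2.0, value=0.7, label='temperature')
    top_p = gr.Slider(0.0, 1.0, value=0.9, label='top_p')
    interface_state = gr.State({k: None for k in TRACKED})
    picture_select = gr.Image(label='Send a picture', type='pil')
    display = gr.Textbox(label='display')

    # One listener, three ordered steps: gather settings, generate, clear the upload field.
    picture_select.upload(
        gather_interface_values, [temperature, top_p], interface_state).then(
        describe_picture, [picture_select, interface_state], display).then(
        lambda: None, None, picture_select, show_progress=False)

demo.launch()

Chaining makes the ordering explicit: the settings snapshot completes before the wrapper runs, and the picture is only cleared after the generation step has been dispatched.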
modules/ui.py

@@ -2,6 +2,8 @@ from pathlib import Path
 
 import gradio as gr
 
+from modules import shared
+
 refresh_symbol = '\U0001f504' # 🔄
 
 with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
@@ -14,6 +16,21 @@ with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
     chat_js = f.read()
 
 
+def list_interface_input_elements(chat=False):
+    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
+    if chat:
+        elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
+    return elements
+
+
+def gather_interface_values(*args):
+    output = {}
+    for i, element in enumerate(shared.input_elements):
+        output[element] = args[i]
+    output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]")
+    return output
+
+
 class ToolButton(gr.Button, gr.components.FormComponent):
     """Small button with single emoji as text, fits inside gradio forms"""
 
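The hunk above moves the two helpers into modules/ui.py, the module that also defines refresh_symbol and ToolButton, so that both server.py and extensions can import them instead of reaching into the server entry point; the extension hunk at the top of this commit does exactly that with "from modules.ui import gather_interface_values". Here is a dependency-free sketch of what gather_interface_values returns. It is not the repo's code: SharedStub stands in for modules.shared, the element list is truncated, and ast.literal_eval is swapped in for the diff's eval call as a stricter equivalent for this input format.

# Dependency-free sketch of the gather step (stand-ins noted above).
import ast

class SharedStub:
    input_elements = ['max_new_tokens', 'temperature', 'custom_stopping_strings']

shared = SharedStub()

def gather_interface_values(*args):
    # Positional widget values arrive in shared.input_elements order.
    output = {element: args[i] for i, element in enumerate(shared.input_elements)}
    # Assumed textbox format: comma-separated quoted strings; wrap in [] and parse.
    # ast.literal_eval is a stricter stand-in for the eval() used in the diff.
    output['custom_stopping_strings'] = ast.literal_eval(f"[{output['custom_stopping_strings']}]")
    return output

state = gather_interface_values(200, 0.7, '"You:", "</s>"')
assert state == {'max_new_tokens': 200,
                 'temperature': 0.7,
                 'custom_stopping_strings': ['You:', '</s>']}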
server.py (42 changed lines)
@@ -24,7 +24,6 @@ from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt, unload_model
 from modules.text_generation import generate_reply, stop_everything_event
 
-
 # Loading custom settings
 settings_file = None
 if shared.args.settings is not None and Path(shared.args.settings).exists():
@@ -361,21 +360,6 @@ else:
     title = 'Text generation web UI'
 
 
-def list_interface_input_elements(chat=False):
-    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
-    if chat:
-        elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
-    return elements
-
-
-def gather_interface_values(*args):
-    output = {}
-    for i, element in enumerate(shared.input_elements):
-        output[element] = args[i]
-    output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]")
-    return output
-
-
 def create_interface():
     gen_events = []
     if shared.args.extensions is not None and len(shared.args.extensions) > 0:
@@ -384,7 +368,7 @@ def create_interface():
 
     with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
         if shared.is_chat():
-            shared.input_elements = list_interface_input_elements(chat=True)
+            shared.input_elements = ui.list_interface_input_elements(chat=True)
             shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
             shared.gradio['Chat input'] = gr.State()
 
@@ -469,33 +453,33 @@ def create_interface():
             reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']]
 
             gen_events.append(shared.gradio['Generate'].click(
-                gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
                 lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
                 chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
                 chat.save_history, shared.gradio['mode'], None, show_progress=False)
             )
 
             gen_events.append(shared.gradio['textbox'].submit(
-                gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
                 lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
                 chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
                 chat.save_history, shared.gradio['mode'], None, show_progress=False)
             )
 
             gen_events.append(shared.gradio['Regenerate'].click(
-                gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
                 chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
                 chat.save_history, shared.gradio['mode'], None, show_progress=False)
             )
 
             gen_events.append(shared.gradio['Continue'].click(
-                gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
                 chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
                 chat.save_history, shared.gradio['mode'], None, show_progress=False)
             )
 
             gen_events.append(shared.gradio['Impersonate'].click(
-                gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
                 chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
             )
 
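Every chat-mode trigger above now starts with the same ui.gather_interface_values step before its wrapper. The Generate and submit chains also contain a "lambda x: (x, '')" hop that stashes the textbox contents in the 'Chat input' State and blanks the visible textbox before generation starts. A standalone sketch of that stash-and-clear step follows; it is not the repo's code, and fake_generate plus the widget names are stand-ins for the repo's components and chat.cai_chatbot_wrapper.

# Standalone sketch of the stash-and-clear step (stand-ins noted above).
import gradio as gr

def fake_generate(stashed, history):
    # Stand-in for the chat wrapper: echo the stashed prompt into the log.
    return (history + f"\nYou: {stashed}\nBot: (reply to '{stashed}')").strip()

with gr.Blocks() as demo:
    chat_input = gr.State()                  # mirrors shared.gradio['Chat input']
    textbox = gr.Textbox(label='Input')
    display = gr.Textbox(label='Chat log', lines=10)
    generate = gr.Button('Generate')

    # Stash the prompt and blank the textbox in one hop, then generate from the stash.
    generate.click(
        lambda x: (x, ''), textbox, [chat_input, textbox], show_progress=False).then(
        fake_generate, [chat_input, display], display)

demo.launch()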
@@ -551,7 +535,7 @@ def create_interface():
             shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
 
         elif shared.args.notebook:
-            shared.input_elements = list_interface_input_elements(chat=False)
+            shared.input_elements = ui.list_interface_input_elements(chat=False)
             shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
             with gr.Tab("Text generation", elem_id="main"):
                 with gr.Row():
@@ -584,13 +568,13 @@ def create_interface():
             output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
 
             gen_events.append(shared.gradio['Generate'].click(
-                gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
                 generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
                 #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
             )
 
             gen_events.append(shared.gradio['textbox'].submit(
-                gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
                 generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
                 #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
             )
@@ -599,7 +583,7 @@ def create_interface():
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
         else:
-            shared.input_elements = list_interface_input_elements(chat=False)
+            shared.input_elements = ui.list_interface_input_elements(chat=False)
             shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
             with gr.Tab("Text generation", elem_id="main"):
                 with gr.Row():
@@ -630,19 +614,19 @@ def create_interface():
             output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
 
             gen_events.append(shared.gradio['Generate'].click(
-                gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
                 generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
                 #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
             )
 
             gen_events.append(shared.gradio['textbox'].submit(
-                gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
                 generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
                 #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
             )
 
             gen_events.append(shared.gradio['Continue'].click(
-                gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
                 generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)#.then(
                 #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
             )
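The notebook-mode and default-mode branches above get the same treatment: ui.gather_interface_values is prepended to each generate_reply chain, while the commented-out trailing step that would auto-scroll the output textarea is left untouched. The sketch below shows the fn=None plus _js pattern those commented lines rely on. It is not the repo's code; it assumes the targeted Gradio 3.x accepts _js on .then() (the commented code and the interface.load(None, None, None, _js=...) context line suggest it does), and if not, the same JS string can be attached to the .click() call instead.

# Standalone sketch of a browser-only follow-up step (assumptions noted above).
import gradio as gr

SCROLL_JS = "() => {const el = document.getElementsByTagName('textarea')[1]; el.scrollTop = el.scrollHeight;}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Input')
    output = gr.Textbox(label='Output', lines=12)
    generate = gr.Button('Generate')

    # Python step first, then a browser-only step: fn=None means no server round-trip.
    generate.click(lambda text: '\n'.join([text] * 100), prompt, output).then(
        None, None, None, _js=SCROLL_JS)

demo.launch()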