mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 23:57:58 +01:00
Fix send_pictures extension
This commit is contained in:
parent
49ce866c99
commit
80f4eabb2a
@ -6,6 +6,7 @@ import torch
|
||||
from transformers import BlipForConditionalGeneration, BlipProcessor
|
||||
|
||||
from modules import chat, shared
|
||||
from modules.ui import gather_interface_values
|
||||
|
||||
# If 'state' is True, will hijack the next chat generation with
|
||||
# custom input text given by 'value' in the format [text, visible_text]
|
||||
@ -42,7 +43,7 @@ def ui():
|
||||
picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
|
||||
|
||||
# Call the generation function
|
||||
picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear the picture from the upload field
|
||||
picture_select.upload(lambda: None, [], [picture_select], show_progress=False)
|
||||
picture_select.upload(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
lambda: None, None, picture_select, show_progress=False)
|
||||
|
@ -2,6 +2,8 @@ from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared
|
||||
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
|
||||
with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
|
||||
@ -14,6 +16,21 @@ with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
|
||||
chat_js = f.read()
|
||||
|
||||
|
||||
def list_interface_input_elements(chat=False):
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
|
||||
if chat:
|
||||
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
|
||||
return elements
|
||||
|
||||
|
||||
def gather_interface_values(*args):
|
||||
output = {}
|
||||
for i, element in enumerate(shared.input_elements):
|
||||
output[element] = args[i]
|
||||
output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]")
|
||||
return output
|
||||
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
||||
|
42
server.py
42
server.py
@ -24,7 +24,6 @@ from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, load_soft_prompt, unload_model
|
||||
from modules.text_generation import generate_reply, stop_everything_event
|
||||
|
||||
|
||||
# Loading custom settings
|
||||
settings_file = None
|
||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||
@ -361,21 +360,6 @@ else:
|
||||
title = 'Text generation web UI'
|
||||
|
||||
|
||||
def list_interface_input_elements(chat=False):
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
|
||||
if chat:
|
||||
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
|
||||
return elements
|
||||
|
||||
|
||||
def gather_interface_values(*args):
|
||||
output = {}
|
||||
for i, element in enumerate(shared.input_elements):
|
||||
output[element] = args[i]
|
||||
output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]")
|
||||
return output
|
||||
|
||||
|
||||
def create_interface():
|
||||
gen_events = []
|
||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||
@ -384,7 +368,7 @@ def create_interface():
|
||||
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
if shared.is_chat():
|
||||
|
||||
shared.input_elements = list_interface_input_elements(chat=True)
|
||||
shared.input_elements = ui.list_interface_input_elements(chat=True)
|
||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||
shared.gradio['Chat input'] = gr.State()
|
||||
|
||||
@ -469,33 +453,33 @@ def create_interface():
|
||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']]
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Regenerate'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Continue'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Impersonate'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
)
|
||||
|
||||
@ -551,7 +535,7 @@ def create_interface():
|
||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
|
||||
|
||||
elif shared.args.notebook:
|
||||
shared.input_elements = list_interface_input_elements(chat=False)
|
||||
shared.input_elements = ui.list_interface_input_elements(chat=False)
|
||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Row():
|
||||
@ -584,13 +568,13 @@ def create_interface():
|
||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
@ -599,7 +583,7 @@ def create_interface():
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
else:
|
||||
shared.input_elements = list_interface_input_elements(chat=False)
|
||||
shared.input_elements = ui.list_interface_input_elements(chat=False)
|
||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Row():
|
||||
@ -630,19 +614,19 @@ def create_interface():
|
||||
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Continue'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)#.then(
|
||||
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user