mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 23:57:58 +01:00
Refactor the UI
A single dictionary called 'interface_state' is now passed as input to all functions. The values are updated only when necessary. The goal is to make it easier to add new elements to the UI.
This commit is contained in:
parent
64f5c90ee7
commit
0f212093a3
@ -74,16 +74,16 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
return prompt
|
||||
|
||||
|
||||
def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
def extract_message_from_reply(reply, state):
|
||||
next_character_found = False
|
||||
|
||||
if stop_at_newline:
|
||||
if state['stop_at_newline']:
|
||||
lines = reply.split('\n')
|
||||
reply = lines[0].strip()
|
||||
if len(lines) > 1:
|
||||
next_character_found = True
|
||||
else:
|
||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
||||
for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
|
||||
idx = reply.find(string)
|
||||
if idx != -1:
|
||||
reply = reply[:idx]
|
||||
@ -92,7 +92,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
# If something like "\nYo" is generated just before "\nYou:"
|
||||
# is completed, trim it
|
||||
if not next_character_found:
|
||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
||||
for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
|
||||
for j in range(len(string) - 1, 0, -1):
|
||||
if reply[-j:] == string[:j]:
|
||||
reply = reply[:-j]
|
||||
@ -105,21 +105,18 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
return reply, next_character_found
|
||||
|
||||
|
||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
if state['mode'] == 'instruct':
|
||||
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||
just_started = True
|
||||
name1_original = name1
|
||||
visible_text = custom_generate_chat_prompt = None
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
|
||||
# Check if any extension wants to hijack this function call
|
||||
for extension, _ in extensions_module.iterator():
|
||||
@ -136,28 +133,28 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {
|
||||
'end_of_turn': end_of_turn,
|
||||
'is_instruct': mode == 'instruct',
|
||||
'end_of_turn': state['end_of_turn'],
|
||||
'is_instruct': state['mode'] == 'instruct',
|
||||
'_continue': _continue
|
||||
}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
prompt = custom_generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not any((regenerate, _continue)):
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
reply = cumulative_reply + reply
|
||||
|
||||
# Extracting the reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
visible_reply = apply_extensions(visible_reply, "output")
|
||||
|
||||
# We need this global variable to handle the Stop event,
|
||||
@ -171,7 +168,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
shared.history['visible'].append(['', ''])
|
||||
|
||||
if _continue:
|
||||
sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply))
|
||||
sep = list(map(lambda x: ' ' if x[-1] != ' ' else '', last_reply))
|
||||
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
||||
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
||||
else:
|
||||
@ -188,28 +185,25 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
yield shared.history['visible']
|
||||
|
||||
|
||||
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
def impersonate_wrapper(text, state):
|
||||
if state['mode'] == 'instruct':
|
||||
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
|
||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], end_of_turn=state['end_of_turn'], impersonate=True)
|
||||
|
||||
# Yield *Is typing...*
|
||||
yield shared.processing_message
|
||||
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
reply = cumulative_reply + reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
yield reply
|
||||
if next_character_found:
|
||||
break
|
||||
@ -220,32 +214,32 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
||||
yield reply
|
||||
|
||||
|
||||
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
yield chat_html_wrapper(history, name1, name2, mode)
|
||||
def cai_chatbot_wrapper(text, state):
|
||||
for history in chatbot_wrapper(text, state):
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])
|
||||
|
||||
|
||||
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
def regenerate_wrapper(text, state):
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
else:
|
||||
last_visible = shared.history['visible'].pop()
|
||||
last_internal = shared.history['internal'].pop()
|
||||
# Yield '*Is typing...*'
|
||||
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
|
||||
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
|
||||
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode'])
|
||||
for history in chatbot_wrapper(last_internal[0], state, regenerate=True):
|
||||
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
|
||||
|
||||
def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
def continue_wrapper(text, state):
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
else:
|
||||
# Yield ' ...'
|
||||
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
|
||||
for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
|
||||
for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
|
||||
|
||||
def remove_last_message(name1, name2, mode):
|
||||
@ -284,7 +278,7 @@ def clear_chat_log(name1, name2, greeting, mode):
|
||||
if greeting != '':
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
|
||||
|
||||
# Save cleared logs
|
||||
save_history(mode)
|
||||
|
||||
@ -452,7 +446,7 @@ def load_character(character, name1, name2, mode):
|
||||
if greeting != "":
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
|
||||
|
||||
# Create .json log files since they don't already exist
|
||||
save_history(mode)
|
||||
|
||||
|
@ -69,6 +69,7 @@ def generate_softprompt_input_tensors(input_ids):
|
||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
return inputs_embeds, filler_input_ids
|
||||
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
def fix_gpt4chan(s):
|
||||
for i in range(10):
|
||||
@ -77,6 +78,7 @@ def fix_gpt4chan(s):
|
||||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||
return s
|
||||
|
||||
|
||||
# Fix the LaTeX equations in galactica
|
||||
def fix_galactica(s):
|
||||
s = s.replace(r'\[', r'$')
|
||||
@ -117,9 +119,9 @@ def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
|
||||
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
|
||||
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(generate_state['seed'])
|
||||
seed = set_manual_seed(state['seed'])
|
||||
shared.stop_everything = False
|
||||
generate_params = {}
|
||||
t0 = time.time()
|
||||
@ -134,8 +136,8 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
# separately and terminate the function call earlier
|
||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params['token_count'] = generate_state['max_new_tokens']
|
||||
generate_params[k] = state[k]
|
||||
generate_params['token_count'] = state['max_new_tokens']
|
||||
try:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
@ -164,7 +166,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token'])
|
||||
input_ids = encode(question, state['max_new_tokens'], add_bos_token=state['add_bos_token'])
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
|
||||
@ -179,13 +181,13 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
|
||||
if not shared.args.flexgen:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params[k] = state[k]
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
else:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params['stop'] = generate_state['eos_token_ids'][-1]
|
||||
generate_params[k] = state[k]
|
||||
generate_params['stop'] = state['eos_token_ids'][-1]
|
||||
if not shared.args.no_stream:
|
||||
generate_params['max_new_tokens'] = 8
|
||||
|
||||
@ -248,7 +250,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(generate_state['max_new_tokens'] // 8 + 1):
|
||||
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
126
server.py
126
server.py
@ -184,22 +184,22 @@ def download_model_wrapper(repo_id):
|
||||
branch = "main"
|
||||
check = False
|
||||
|
||||
yield("Cleaning up the model/branch names")
|
||||
yield ("Cleaning up the model/branch names")
|
||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||
|
||||
yield("Getting the download links from Hugging Face")
|
||||
yield ("Getting the download links from Hugging Face")
|
||||
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||
|
||||
yield("Getting the output folder")
|
||||
yield ("Getting the output folder")
|
||||
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
||||
|
||||
if check:
|
||||
yield("Checking previously downloaded files")
|
||||
yield ("Checking previously downloaded files")
|
||||
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||
else:
|
||||
yield(f"Downloading files to {output_folder}")
|
||||
yield (f"Downloading files to {output_folder}")
|
||||
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
||||
yield("Done!")
|
||||
yield ("Done!")
|
||||
except:
|
||||
yield traceback.format_exc()
|
||||
|
||||
@ -357,6 +357,20 @@ else:
|
||||
title = 'Text generation web UI'
|
||||
|
||||
|
||||
def list_interface_input_elements(chat=False):
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token']
|
||||
if chat:
|
||||
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
|
||||
return elements
|
||||
|
||||
|
||||
def gather_interface_values(*args):
|
||||
output = {}
|
||||
for i, element in enumerate(shared.input_elements):
|
||||
output[element] = args[i]
|
||||
return output
|
||||
|
||||
|
||||
def create_interface():
|
||||
gen_events = []
|
||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||
@ -364,7 +378,11 @@ def create_interface():
|
||||
|
||||
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
if shared.is_chat():
|
||||
|
||||
shared.input_elements = list_interface_input_elements(chat=True)
|
||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||
shared.gradio['Chat input'] = gr.State()
|
||||
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||
@ -384,7 +402,7 @@ def create_interface():
|
||||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||
|
||||
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
||||
shared.gradio["mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
||||
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.")
|
||||
|
||||
with gr.Tab("Character", elem_id="chat-settings"):
|
||||
@ -439,75 +457,85 @@ def create_interface():
|
||||
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
|
||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]
|
||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
|
||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']]
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Regenerate'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Continue'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Impersonate'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
)
|
||||
|
||||
shared.gradio['Replace last reply'].click(
|
||||
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
|
||||
shared.gradio['Clear history-confirm'].click(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
||||
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'mode']], shared.gradio['display']).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||
|
||||
shared.gradio['Stop'].click(
|
||||
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
shared.gradio['Chat mode'].change(
|
||||
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
|
||||
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['Chat mode'], shared.gradio['character_menu']).then(
|
||||
shared.gradio['mode'].change(
|
||||
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['mode'], shared.gradio['Instruction templates']).then(
|
||||
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['mode'], shared.gradio['character_menu']).then(
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
shared.gradio['Instruction templates'].change(
|
||||
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
||||
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
shared.gradio['upload_chat_history'].upload(
|
||||
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['Chat mode'], shared.gradio['download'])
|
||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
|
||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
||||
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
|
||||
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
|
||||
|
||||
elif shared.args.notebook:
|
||||
shared.input_elements = list_interface_input_elements(chat=False)
|
||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
@ -537,12 +565,23 @@ def create_interface():
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
|
||||
)
|
||||
|
||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
else:
|
||||
shared.input_elements = list_interface_input_elements(chat=False)
|
||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@ -570,9 +609,22 @@ def create_interface():
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
||||
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Continue'].click(
|
||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)
|
||||
)
|
||||
|
||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
@ -607,18 +659,6 @@ def create_interface():
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
def change_dict_value(d, key, value):
|
||||
d[key] = value
|
||||
return d
|
||||
|
||||
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'add_bos_token', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
|
||||
if k not in shared.gradio:
|
||||
continue
|
||||
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
|
||||
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
||||
else:
|
||||
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
||||
|
||||
if not shared.is_chat():
|
||||
api.create_apis()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user