Implement sessions + add basic multi-user support (#2991)

This commit is contained in:
oobabooga 2023-07-04 00:03:30 -03:00 committed by GitHub
parent 1f8cae14f9
commit 4b1804a438
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 595 additions and 414 deletions

View File

@ -193,6 +193,7 @@ Optionally, you can use the following command-line flags:
| `-h`, `--help` | Show this help message and exit. | | `-h`, `--help` | Show this help message and exit. |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode. | | `--chat` | Launch the web UI in chat mode. |
| `--multi-user` | Multi-user mode. Chat histories are not saved or automatically loaded. WARNING: this is highly experimental. |
| `--character CHARACTER` | The name of the character to load in chat mode by default. | | `--character CHARACTER` | The name of the character to load in chat mode by default. |
| `--model MODEL` | Name of the model to load by default. | | `--model MODEL` | Name of the model to load by default. |
| `--lora LORA [LORA ...]` | The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces. | | `--lora LORA [LORA ...]` | The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces. |

View File

@ -6,6 +6,12 @@
padding-top: 2.5rem padding-top: 2.5rem
} }
.small-button {
max-width: 171px;
height: 39.594px;
align-self: end;
}
.refresh-button { .refresh-button {
max-width: 4.4em; max-width: 4.4em;
min-width: 2.2em !important; min-width: 2.2em !important;
@ -50,7 +56,7 @@ ol li p, ul li p {
display: inline-block; display: inline-block;
} }
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab, #model-tab { #main, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab {
border: 0; border: 0;
} }
@ -121,10 +127,6 @@ button {
font-size: 14px !important; font-size: 14px !important;
} }
.small-button {
max-width: 171px;
}
.file-saver { .file-saver {
position: fixed !important; position: fixed !important;
top: 50%; top: 50%;

View File

@ -38,11 +38,11 @@ script.py may define the special functions and variables below.
| `def ui()` | Creates custom gradio elements when the UI is launched. | | `def ui()` | Creates custom gradio elements when the UI is launched. |
| `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. | | `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. |
| `def custom_js()` | Same as above but for javascript. | | `def custom_js()` | Same as above but for javascript. |
| `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def input_modifier(string, state)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
| `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def output_modifier(string, state)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
| `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. |
| `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. | | `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. |
| `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. | | `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. |
| `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply. |
| `def custom_generate_reply(...)` | Overrides the main text generation function. | | `def custom_generate_reply(...)` | Overrides the main text generation function. |
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. | | `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. | | `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. |

View File

@ -3,7 +3,9 @@ from pathlib import Path
import elevenlabs import elevenlabs
import gradio as gr import gradio as gr
from modules import chat, shared from modules import chat, shared
from modules.utils import gradio
params = { params = {
'activate': True, 'activate': True,
@ -35,24 +37,24 @@ def refresh_voices_dd():
return gr.Dropdown.update(value=all_voices[0], choices=all_voices) return gr.Dropdown.update(value=all_voices[0], choices=all_voices)
def remove_tts_from_history(): def remove_tts_from_history(history):
for i, entry in enumerate(shared.history['internal']): for i, entry in enumerate(history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]] history['visible'][i] = [history['visible'][i][0], entry[1]]
return history
def toggle_text_in_history(): def toggle_text_in_history(history):
for i, entry in enumerate(shared.history['visible']): for i, entry in enumerate(history['visible']):
visible_reply = entry[1] visible_reply = entry[1]
if visible_reply.startswith('<audio'): if visible_reply.startswith('<audio'):
if params['show_text']: if params['show_text']:
reply = shared.history['internal'][i][1] reply = history['internal'][i][1]
shared.history['visible'][i] = [ history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"
]
else: else:
shared.history['visible'][i] = [ history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"
] return history
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
@ -150,25 +152,24 @@ def ui():
convert_cancel = gr.Button('Cancel', visible=False) convert_cancel = gr.Button('Cancel', visible=False)
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False) convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
if shared.is_chat():
# Convert history with confirmation # Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel] convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click( convert_confirm.click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, None, None).then( remove_tts_from_history, gradio('history'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history # Toggle message text in history
show_text.change( show_text.change(
lambda x: params.update({"show_text": x}), show_text, None).then( lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, None, None).then( toggle_text_in_history, gradio('history'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
activate.change(lambda x: params.update({'activate': x}), activate, None) activate.change(lambda x: params.update({'activate': x}), activate, None)

View File

@ -10,7 +10,6 @@ import requests
import torch import torch
from PIL import Image from PIL import Image
import modules.shared as shared
from modules.models import reload_model, unload_model from modules.models import reload_model, unload_model
from modules.ui import create_refresh_button from modules.ui import create_refresh_button
@ -126,7 +125,7 @@ def input_modifier(string):
return string return string
# Get and save the Stable Diffusion-generated picture # Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description): def get_SD_pictures(description, character):
global params global params
@ -160,7 +159,7 @@ def get_SD_pictures(description):
if params['save_img']: if params['save_img']:
img_data = base64.b64decode(img_str) img_data = base64.b64decode(img_str)
variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}' variadic = f'{date.today().strftime("%Y_%m_%d")}/{character}_{int(time.time())}'
output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png') output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
output_file.parent.mkdir(parents=True, exist_ok=True) output_file.parent.mkdir(parents=True, exist_ok=True)
@ -186,7 +185,7 @@ def get_SD_pictures(description):
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging? # and replace it with 'text' for the purposes of logging?
def output_modifier(string): def output_modifier(string, state):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
""" """
@ -213,7 +212,7 @@ def output_modifier(string):
else: else:
text = string text = string
string = get_SD_pictures(string) + "\n" + text string = get_SD_pictures(string, state['character_menu']) + "\n" + text
return string return string

View File

@ -3,9 +3,10 @@ from pathlib import Path
import gradio as gr import gradio as gr
import torch import torch
from modules import chat, shared
from extensions.silero_tts import tts_preprocessor from extensions.silero_tts import tts_preprocessor
from modules import chat, shared
from modules.utils import gradio
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
@ -56,20 +57,24 @@ def load_model():
return model return model
def remove_tts_from_history(): def remove_tts_from_history(history):
for i, entry in enumerate(shared.history['internal']): for i, entry in enumerate(history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]] history['visible'][i] = [history['visible'][i][0], entry[1]]
return history
def toggle_text_in_history(): def toggle_text_in_history(history):
for i, entry in enumerate(shared.history['visible']): for i, entry in enumerate(history['visible']):
visible_reply = entry[1] visible_reply = entry[1]
if visible_reply.startswith('<audio'): if visible_reply.startswith('<audio'):
if params['show_text']: if params['show_text']:
reply = shared.history['internal'][i][1] reply = history['internal'][i][1]
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"] history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
else: else:
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"] history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
return history
def state_modifier(state): def state_modifier(state):
@ -80,7 +85,7 @@ def state_modifier(state):
return state return state
def input_modifier(string): def input_modifier(string, state):
if not params['activate']: if not params['activate']:
return string return string
@ -99,7 +104,7 @@ def history_modifier(history):
return history return history
def output_modifier(string): def output_modifier(string, state):
global model, current_params, streaming_state global model, current_params, streaming_state
for i in params: for i in params:
if params[i] != current_params[i]: if params[i] != current_params[i]:
@ -116,7 +121,7 @@ def output_modifier(string):
if string == '': if string == '':
string = '*Empty reply, try regenerating*' string = '*Empty reply, try regenerating*'
else: else:
output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav') output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch']) prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>' silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
@ -155,23 +160,24 @@ def ui():
gr.Markdown('[Click here for Silero audio samples](https://oobabooga.github.io/silero-samples/index.html)') gr.Markdown('[Click here for Silero audio samples](https://oobabooga.github.io/silero-samples/index.html)')
if shared.is_chat():
# Convert history with confirmation # Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel] convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click( convert_confirm.click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, None, None).then( remove_tts_from_history, gradio('history'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history # Toggle message text in history
show_text.change( show_text.change(
lambda x: params.update({"show_text": x}), show_text, None).then( lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, None, None).then( toggle_text_in_history, gradio('history'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None) activate.change(lambda x: params.update({"activate": x}), activate, None)

View File

@ -96,6 +96,8 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight):
def custom_generate_chat_prompt(user_input, state, **kwargs): def custom_generate_chat_prompt(user_input, state, **kwargs):
global chat_collector global chat_collector
history = state['history']
if state['mode'] == 'instruct': if state['mode'] == 'instruct':
results = collector.get_sorted(user_input, n_results=params['chunk_count']) results = collector.get_sorted(user_input, n_results=params['chunk_count'])
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results) additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
@ -104,29 +106,29 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
def make_single_exchange(id_): def make_single_exchange(id_):
output = '' output = ''
output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n" output += f"{state['name1']}: {history['internal'][id_][0]}\n"
output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n" output += f"{state['name2']}: {history['internal'][id_][1]}\n"
return output return output
if len(shared.history['internal']) > params['chunk_count'] and user_input != '': if len(history['internal']) > params['chunk_count'] and user_input != '':
chunks = [] chunks = []
hist_size = len(shared.history['internal']) hist_size = len(history['internal'])
for i in range(hist_size-1): for i in range(hist_size-1):
chunks.append(make_single_exchange(i)) chunks.append(make_single_exchange(i))
add_chunks_to_collector(chunks, chat_collector) add_chunks_to_collector(chunks, chat_collector)
query = '\n'.join(shared.history['internal'][-1] + [user_input]) query = '\n'.join(history['internal'][-1] + [user_input])
try: try:
best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight']) best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
additional_context = '\n' additional_context = '\n'
for id_ in best_ids: for id_ in best_ids:
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>': if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
additional_context += make_single_exchange(id_) additional_context += make_single_exchange(id_)
logger.warning(f'Adding the following new context:\n{additional_context}') logger.warning(f'Adding the following new context:\n{additional_context}')
state['context'] = state['context'].strip() + '\n' + additional_context state['context'] = state['context'].strip() + '\n' + additional_context
kwargs['history'] = { kwargs['history'] = {
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids], 'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids],
'visible': '' 'visible': ''
} }
except RuntimeError: except RuntimeError:

View File

@ -3,7 +3,6 @@ import copy
import functools import functools
import json import json
import re import re
from datetime import datetime
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
@ -11,7 +10,6 @@ import yaml
from PIL import Image from PIL import Image
import modules.shared as shared import modules.shared as shared
from modules import utils
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import chat_html_wrapper, make_thumbnail from modules.html_generator import chat_html_wrapper, make_thumbnail
from modules.logging_colors import logger from modules.logging_colors import logger
@ -20,7 +18,12 @@ from modules.text_generation import (
get_encoded_length, get_encoded_length,
get_max_prompt_length get_max_prompt_length
) )
from modules.utils import delete_file, replace_all, save_file from modules.utils import (
delete_file,
get_available_characters,
replace_all,
save_file
)
def get_turn_substrings(state, instruct=False): def get_turn_substrings(state, instruct=False):
@ -54,7 +57,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs.get('impersonate', False) impersonate = kwargs.get('impersonate', False)
_continue = kwargs.get('_continue', False) _continue = kwargs.get('_continue', False)
also_return_rows = kwargs.get('also_return_rows', False) also_return_rows = kwargs.get('also_return_rows', False)
history = kwargs.get('history', shared.history)['internal'] history = kwargs.get('history', state['history'])['internal']
is_instruct = state['mode'] == 'instruct' is_instruct = state['mode'] == 'instruct'
# Find the maximum prompt size # Find the maximum prompt size
@ -76,10 +79,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
if impersonate: if impersonate:
wrapper += substrings['user_turn_stripped'].rstrip(' ') wrapper += substrings['user_turn_stripped'].rstrip(' ')
elif _continue: elif _continue:
wrapper += apply_extensions("bot_prefix", substrings['bot_turn_stripped']) wrapper += apply_extensions('bot_prefix', substrings['bot_turn_stripped'], state)
wrapper += history[-1][1] wrapper += history[-1][1]
else: else:
wrapper += apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' ')) wrapper += apply_extensions('bot_prefix', substrings['bot_turn_stripped'].rstrip(' '), state)
else: else:
wrapper = '<|prompt|>' wrapper = '<|prompt|>'
@ -113,7 +116,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Add the character prefix # Add the character prefix
if state['mode'] != 'chat-instruct': if state['mode'] != 'chat-instruct':
rows.append(apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' '))) rows.append(apply_extensions('bot_prefix', substrings['bot_turn_stripped'].rstrip(' '), state))
while len(rows) > min_rows and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) >= max_length: while len(rows) > min_rows and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) >= max_length:
rows.pop(1) rows.pop(1)
@ -153,7 +156,8 @@ def get_stopping_strings(state):
return stopping_strings return stopping_strings
def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True): def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True):
history = state['history']
output = copy.deepcopy(history) output = copy.deepcopy(history)
output = apply_extensions('history', output) output = apply_extensions('history', output)
state = apply_extensions('state', state) state = apply_extensions('state', state)
@ -174,11 +178,11 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
if visible_text is None: if visible_text is None:
visible_text = text visible_text = text
text = apply_extensions('input', text, state)
# *Is typing...* # *Is typing...*
if loading_message: if loading_message:
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
text = apply_extensions('input', text)
else: else:
text, visible_text = output['internal'][-1][0], output['visible'][-1][0] text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
if regenerate: if regenerate:
@ -215,7 +219,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
# We need this global variable to handle the Stop event, # We need this global variable to handle the Stop event,
# otherwise gradio gets confused # otherwise gradio gets confused
if shared.stop_everything: if shared.stop_everything:
output['visible'][-1][1] = apply_extensions("output", output['visible'][-1][1]) output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state)
yield output yield output
return return
@ -241,7 +245,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
else: else:
cumulative_reply = reply cumulative_reply = reply
output['visible'][-1][1] = apply_extensions("output", output['visible'][-1][1]) output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state)
yield output yield output
@ -274,14 +278,15 @@ def impersonate_wrapper(text, start_with, state):
yield cumulative_reply.lstrip(' ') yield cumulative_reply.lstrip(' ')
def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True): def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True):
history = state['history']
if regenerate or _continue: if regenerate or _continue:
text = '' text = ''
if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0: if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
yield history yield history
return return
for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message): for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
yield history yield history
@ -296,144 +301,116 @@ def generate_chat_reply_wrapper(text, start_with, state, regenerate=False, _cont
send_dummy_message(text) send_dummy_message(text)
send_dummy_reply(start_with) send_dummy_reply(start_with)
for i, history in enumerate(generate_chat_reply(text, shared.history, state, regenerate, _continue, loading_message=True)): for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True)):
if i != 0: yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style']), history
shared.history = copy.deepcopy(history)
yield chat_html_wrapper(history['visible'], state['name1'], state['name2'], state['mode'], state['chat_style'])
def remove_last_message(): def remove_last_message(history):
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': if len(history['visible']) > 0 and history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop() last = history['visible'].pop()
shared.history['internal'].pop() history['internal'].pop()
else: else:
last = ['', ''] last = ['', '']
return last[0] return last[0], history
def send_last_reply_to_input(): def send_last_reply_to_input(history):
if len(shared.history['internal']) > 0: if len(history['internal']) > 0:
return shared.history['internal'][-1][1] return history['internal'][-1][1]
else: else:
return '' return ''
def replace_last_reply(text): def replace_last_reply(text, state):
if len(shared.history['visible']) > 0: history = state['history']
shared.history['visible'][-1][1] = text if len(history['visible']) > 0:
shared.history['internal'][-1][1] = apply_extensions("input", text) history['visible'][-1][1] = text
history['internal'][-1][1] = apply_extensions('input', text, state)
return history
def send_dummy_message(text): def send_dummy_message(text, state):
shared.history['visible'].append([text, '']) history = state['history']
shared.history['internal'].append([apply_extensions("input", text), '']) history['visible'].append([text, ''])
history['internal'].append([apply_extensions('input', text, state), ''])
return history
def send_dummy_reply(text): def send_dummy_reply(text, state):
if len(shared.history['visible']) > 0 and not shared.history['visible'][-1][1] == '': history = state['history']
shared.history['visible'].append(['', '']) if len(history['visible']) > 0 and not history['visible'][-1][1] == '':
shared.history['internal'].append(['', '']) history['visible'].append(['', ''])
history['internal'].append(['', ''])
shared.history['visible'][-1][1] = text history['visible'][-1][1] = text
shared.history['internal'][-1][1] = apply_extensions("input", text) history['internal'][-1][1] = apply_extensions('input', text, state)
return history
def clear_chat_log(greeting, mode): def clear_chat_log(state):
shared.history['visible'] = [] greeting = state['greeting']
shared.history['internal'] = [] mode = state['mode']
history = state['history']
history['visible'] = []
history['internal'] = []
if mode != 'instruct': if mode != 'instruct':
if greeting != '': if greeting != '':
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions("output", greeting)]] history['visible'] += [['', apply_extensions('output', greeting, state)]]
save_history(mode)
def redraw_html(name1, name2, mode, style, reset_cache=False):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, style, reset_cache=reset_cache)
def tokenize_dialogue(dialogue, name1, name2):
history = []
messages = []
dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('<start>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\g<1>{name2}:', dialogue)
idx = [m.start() for m in re.finditer(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):", dialogue)]
if len(idx) == 0:
return history
for i in range(len(idx) - 1):
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
messages.append(dialogue[idx[-1]:].strip())
entry = ['', '']
for i in messages:
if i.startswith(f'{name1}:'):
entry[0] = i[len(f'{name1}:'):].strip()
elif i.startswith(f'{name2}:'):
entry[1] = i[len(f'{name2}:'):].strip()
if not (len(entry[0]) == 0 and len(entry[1]) == 0):
history.append(entry)
entry = ['', '']
print("\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
for row in history:
for column in row:
print("\n")
for line in column.strip().split('\n'):
print("| " + line + "\n")
print("|\n")
print("------------------------------")
return history return history
def save_history(mode, timestamp=False, user_request=False): def redraw_html(history, name1, name2, mode, style, reset_cache=False):
# Instruct mode histories should not be saved as if return chat_html_wrapper(history, name1, name2, mode, style, reset_cache=reset_cache)
# Alpaca or Vicuna were characters
if mode == 'instruct':
if not timestamp:
return
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else:
if shared.character == 'None' and not user_request:
return
if timestamp:
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else:
fname = f"{shared.character}_persistent.json"
if not Path('logs').exists():
Path('logs').mkdir()
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}')
def load_history(file, name1, name2): def save_history(history, path=None):
file = file.decode('utf-8') p = path or Path('logs/exported_history.json')
with open(p, 'w', encoding='utf-8') as f:
f.write(json.dumps(history, indent=4))
return p
def load_history(file, history):
try: try:
file = file.decode('utf-8')
j = json.loads(file) j = json.loads(file)
if 'data' in j: if 'internal' in j and 'visible' in j:
shared.history['internal'] = j['data'] return j
if 'data_visible' in j:
shared.history['visible'] = j['data_visible']
else: else:
shared.history['visible'] = copy.deepcopy(shared.history['internal']) return history
except: except:
shared.history['internal'] = tokenize_dialogue(file, name1, name2) return history
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
def save_persistent_history(history, character, mode):
if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None] and not shared.args.multi_user:
save_history(history, path=Path(f'logs/{character}_persistent.json'))
def load_persistent_history(state):
if shared.args.multi_user or state['mode'] == 'instruct':
return state['history']
character = state['character_menu']
greeting = state['greeting']
p = Path(f'logs/{character}_persistent.json')
if character not in ['None', '', None] and p.exists():
f = json.loads(open(p, 'rb').read())
if 'internal' in f and 'visible' in f:
history = f
else:
history = {'internal': [], 'visible': []}
if greeting != "":
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
history['visible'] += [['', apply_extensions('output', greeting, state)]]
return history
def replace_character_names(text, name1, name2): def replace_character_names(text, name1, name2):
@ -468,7 +445,6 @@ def generate_pfp_cache(character):
def load_character(character, name1, name2, instruct=False): def load_character(character, name1, name2, instruct=False):
shared.character = character
context = greeting = turn_template = "" context = greeting = turn_template = ""
greeting_field = 'greeting' greeting_field = 'greeting'
picture = None picture = None
@ -477,7 +453,7 @@ def load_character(character, name1, name2, instruct=False):
if Path("cache/pfp_character.png").exists(): if Path("cache/pfp_character.png").exists():
Path("cache/pfp_character.png").unlink() Path("cache/pfp_character.png").unlink()
if character != 'None': if character not in ['None', '', None]:
folder = 'characters' if not instruct else 'characters/instruction-following' folder = 'characters' if not instruct else 'characters/instruction-following'
picture = generate_pfp_cache(character) picture = generate_pfp_cache(character)
for extension in ["yml", "yaml", "json"]: for extension in ["yml", "yaml", "json"]:
@ -527,20 +503,6 @@ def load_character(character, name1, name2, instruct=False):
greeting = shared.settings['greeting'] greeting = shared.settings['greeting']
turn_template = shared.settings['turn_template'] turn_template = shared.settings['turn_template']
if not instruct:
shared.history['internal'] = []
shared.history['visible'] = []
if shared.character != 'None' and Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
else:
# Insert greeting if it exists
if greeting != "":
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
# Create .json log files since they don't already exist
save_history('instruct' if instruct else 'chat')
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n") return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")
@ -568,7 +530,7 @@ def upload_character(json_file, img, tavern=False):
img.save(Path(f'characters/{outfile_name}.png')) img.save(Path(f'characters/{outfile_name}.png'))
logger.info(f'New character saved to "characters/{outfile_name}.json".') logger.info(f'New character saved to "characters/{outfile_name}.json".')
return gr.update(value=outfile_name, choices=utils.get_available_characters()) return gr.update(value=outfile_name, choices=get_available_characters())
def upload_tavern_character(img, _json): def upload_tavern_character(img, _json):

View File

@ -6,6 +6,8 @@ import gradio as gr
import extensions import extensions
import modules.shared as shared import modules.shared as shared
from modules.logging_colors import logger from modules.logging_colors import logger
from inspect import signature
state = {} state = {}
available_extensions = [] available_extensions = []
@ -52,10 +54,14 @@ def iterator():
# Extension functions that map string -> string # Extension functions that map string -> string
def _apply_string_extensions(function_name, text): def _apply_string_extensions(function_name, text, state):
for extension, _ in iterator(): for extension, _ in iterator():
if hasattr(extension, function_name): if hasattr(extension, function_name):
text = getattr(extension, function_name)(text) func = getattr(extension, function_name)
if len(signature(func).parameters) == 2:
text = func(text, state)
else:
text = func(text)
return text return text

View File

@ -14,16 +14,20 @@ def clone_or_pull_repository(github_url):
# Check if the repository is already cloned # Check if the repository is already cloned
if os.path.exists(repo_path): if os.path.exists(repo_path):
yield f"Updating {github_url}..."
# Perform a 'git pull' to update the repository # Perform a 'git pull' to update the repository
try: try:
pull_output = subprocess.check_output(["git", "-C", repo_path, "pull"], stderr=subprocess.STDOUT) pull_output = subprocess.check_output(["git", "-C", repo_path, "pull"], stderr=subprocess.STDOUT)
yield "Done."
return pull_output.decode() return pull_output.decode()
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
return str(e) return str(e)
# Clone the repository # Clone the repository
try: try:
yield f"Cloning {github_url}..."
clone_output = subprocess.check_output(["git", "clone", github_url, repo_path], stderr=subprocess.STDOUT) clone_output = subprocess.check_output(["git", "clone", github_url, repo_path], stderr=subprocess.STDOUT)
yield "Done."
return clone_output.decode() return clone_output.decode()
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
return str(e) return str(e)

View File

@ -266,8 +266,8 @@ def generate_chat_html(history, name1, name2, reset_cache=False):
def chat_html_wrapper(history, name1, name2, mode, style, reset_cache=False): def chat_html_wrapper(history, name1, name2, mode, style, reset_cache=False):
if mode == 'instruct': if mode == 'instruct':
return generate_instruct_html(history) return generate_instruct_html(history['visible'])
elif style == 'wpp': elif style == 'wpp':
return generate_chat_html(history, name1, name2) return generate_chat_html(history['visible'], name1, name2)
else: else:
return generate_cai_chat_html(history, name1, name2, style, reset_cache) return generate_cai_chat_html(history['visible'], name1, name2, style, reset_cache)

View File

@ -29,6 +29,7 @@ def load_preset(name):
'mirostat_eta': 0.1, 'mirostat_eta': 0.1,
} }
if name not in ['None', None, '']:
with open(Path(f'presets/{name}.yaml'), 'r') as infile: with open(Path(f'presets/{name}.yaml'), 'r') as infile:
preset = yaml.safe_load(infile) preset = yaml.safe_load(infile)

View File

@ -14,8 +14,6 @@ model_name = "None"
lora_names = [] lora_names = []
# Chat variables # Chat variables
history = {'internal': [], 'visible': []}
character = 'None'
stop_everything = False stop_everything = False
processing_message = '*Is typing...*' processing_message = '*Is typing...*'
@ -83,6 +81,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
# Basic settings # Basic settings
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.') parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
parser.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. WARNING: this is highly experimental.')
parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.') parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.')
parser.add_argument('--model', type=str, help='Name of the model to load by default.') parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, nargs="+", help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.') parser.add_argument('--lora', type=str, nargs="+", help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
@ -204,6 +203,8 @@ if args.trust_remote_code:
logger.warning("trust_remote_code is enabled. This is dangerous.") logger.warning("trust_remote_code is enabled. This is dangerous.")
if args.share: if args.share:
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.") logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
if args.multi_user:
logger.warning("The multi-user mode is highly experimental. DO NOT EXPOSE IT TO THE INTERNET.")
def fix_loader_name(name): def fix_loader_name(name):
@ -246,6 +247,15 @@ def is_chat():
return args.chat return args.chat
def get_mode():
if args.chat:
return 'chat'
elif args.notebook:
return 'notebook'
else:
return 'default'
# Loading model-specific settings # Loading model-specific settings
with Path(f'{args.model_dir}/config.yaml') as p: with Path(f'{args.model_dir}/config.yaml') as p:
if p.exists(): if p.exists():

View File

@ -190,7 +190,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
original_question = question original_question = question
if not is_chat: if not is_chat:
state = apply_extensions('state', state) state = apply_extensions('state', state)
question = apply_extensions('input', question) question = apply_extensions('input', question, state)
# Finding the stopping strings # Finding the stopping strings
all_stop_strings = [] all_stop_strings = []
@ -223,7 +223,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
break break
if not is_chat: if not is_chat:
reply = apply_extensions('output', reply) reply = apply_extensions('output', reply, state)
yield reply yield reply
@ -262,7 +262,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
generate_params['eos_token_id'] = eos_token_ids generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList() generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria()); generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
t0 = time.time() t0 = time.time()
try: try:

View File

@ -1,3 +1,4 @@
import json
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
@ -5,6 +6,7 @@ import torch
from modules import shared from modules import shared
with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f: with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
css = f.read() css = f.read()
with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f: with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
@ -14,7 +16,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f: with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
chat_js = f.read() chat_js = f.read()
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '🔄'
delete_symbol = '🗑️' delete_symbol = '🗑️'
save_symbol = '💾' save_symbol = '💾'
@ -30,17 +32,103 @@ theme = gr.themes.Default(
def list_model_elements(): def list_model_elements():
elements = ['loader', 'cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'trust_remote_code', 'load_in_4bit', 'compute_dtype', 'quant_type', 'use_double_quant', 'wbits', 'groupsize', 'model_type', 'pre_layer', 'triton', 'desc_act', 'no_inject_fused_attention', 'no_inject_fused_mlp', 'no_use_cuda_fp16', 'threads', 'n_batch', 'no_mmap', 'mlock', 'n_gpu_layers', 'n_ctx', 'llama_cpp_seed', 'gpu_split', 'max_seq_len', 'compress_pos_emb'] elements = [
'loader',
'cpu_memory',
'auto_devices',
'disk',
'cpu',
'bf16',
'load_in_8bit',
'trust_remote_code',
'load_in_4bit',
'compute_dtype',
'quant_type',
'use_double_quant',
'wbits',
'groupsize',
'model_type',
'pre_layer',
'triton',
'desc_act',
'no_inject_fused_attention',
'no_inject_fused_mlp',
'no_use_cuda_fp16',
'threads',
'n_batch',
'no_mmap',
'mlock',
'n_gpu_layers',
'n_ctx',
'llama_cpp_seed',
'gpu_split',
'max_seq_len',
'compress_pos_emb'
]
for i in range(torch.cuda.device_count()): for i in range(torch.cuda.device_count()):
elements.append(f'gpu_memory_{i}') elements.append(f'gpu_memory_{i}')
return elements return elements
def list_interface_input_elements(chat=False): def list_interface_input_elements():
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream', 'tfs', 'top_a'] elements = [
if chat: 'preset_menu',
elements += ['name1', 'name2', 'greeting', 'context', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command'] 'max_new_tokens',
'seed',
'temperature',
'top_p',
'top_k',
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'repetition_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'min_length',
'do_sample',
'penalty_alpha',
'num_beams',
'length_penalty',
'early_stopping',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'add_bos_token',
'ban_eos_token',
'truncation_length',
'custom_stopping_strings',
'skip_special_tokens',
'stream',
'tfs',
'top_a',
]
if shared.args.chat:
elements += [
'character_menu',
'history',
'name1',
'name2',
'greeting',
'context',
'chat_generation_attempts',
'stop_at_newline',
'mode',
'instruction_template',
'name1_instruct',
'name2_instruct',
'context_instruct',
'turn_template',
'chat_style',
'chat-instruct_command',
]
else:
elements.append('textbox')
if not shared.args.notebook:
elements.append('output_textbox')
elements += list_model_elements() elements += list_model_elements()
return elements return elements
@ -48,10 +136,14 @@ def list_interface_input_elements(chat=False):
def gather_interface_values(*args): def gather_interface_values(*args):
output = {} output = {}
for i, element in enumerate(shared.input_elements): for i, element in enumerate(list_interface_input_elements()):
output[element] = args[i] output[element] = args[i]
if not shared.args.multi_user:
shared.persistent_interface_state = output shared.persistent_interface_state = output
with open(Path(f'logs/session_{shared.get_mode()}_autosave.json'), 'w') as f:
f.write(json.dumps(output, indent=4))
return output return output
@ -59,11 +151,12 @@ def apply_interface_values(state, use_persistent=False):
if use_persistent: if use_persistent:
state = shared.persistent_interface_state state = shared.persistent_interface_state
elements = list_interface_input_elements(chat=shared.is_chat()) elements = list_interface_input_elements()
if len(state) == 0: if len(state) == 0:
return [gr.update() for k in elements] # Dummy, do nothing return [gr.update() for k in elements] # Dummy, do nothing
else: else:
return [state[k] if k in state else gr.update() for k in elements] ans = [state[k] if k in state else gr.update() for k in elements]
return ans
class ToolButton(gr.Button, gr.components.FormComponent): class ToolButton(gr.Button, gr.components.FormComponent):
@ -92,6 +185,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
inputs=[], inputs=[],
outputs=[refresh_component] outputs=[refresh_component]
) )
return refresh_button return refresh_button

View File

@ -7,6 +7,14 @@ from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
# Helper function to get multiple values from shared.gradio
def gradio(*keys):
if len(keys) == 1 and type(keys[0]) is list:
keys = keys[0]
return [shared.gradio[k] for k in keys]
def save_file(fname, contents): def save_file(fname, contents):
if fname == '': if fname == '':
logger.error('File name is empty!') logger.error('File name is empty!')
@ -111,3 +119,8 @@ def get_datasets(path: str, ext: str):
def get_available_chat_styles(): def get_available_chat_styles():
return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys) return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys)
def get_available_sessions():
items = sorted(set(k.stem for k in Path('logs').glob(f'session_{shared.get_mode()}*')), key=natural_keys, reverse=True)
return [item for item in items if 'autosave' in item] + [item for item in items if 'autosave' not in item]

396
server.py
View File

@ -49,6 +49,7 @@ from modules.text_generation import (
get_encoded_length, get_encoded_length,
stop_everything_event stop_everything_event
) )
from modules.utils import gradio
def load_model_wrapper(selected_model, loader, autoload=False): def load_model_wrapper(selected_model, loader, autoload=False):
@ -257,40 +258,40 @@ def create_model_menus():
with gr.Row(): with gr.Row():
shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready') shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
shared.gradio['loader'].change(loaders.make_loader_params_visible, shared.gradio['loader'], [shared.gradio[k] for k in loaders.get_all_params()]) shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()))
# In this event handler, the interface state is read and updated # In this event handler, the interface state is read and updated
# with the model defaults (if any), and then the model is loaded # with the model defaults (if any), and then the model is loaded
# unless "autoload_model" is unchecked # unless "autoload_model" is unchecked
shared.gradio['model_menu'].change( shared.gradio['model_menu'].change(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
apply_model_settings_to_state, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['interface_state']).then( apply_model_settings_to_state, gradio('model_menu', 'interface_state'), gradio('interface_state')).then(
ui.apply_interface_values, shared.gradio['interface_state'], [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then( ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
update_model_parameters, shared.gradio['interface_state'], None).then( update_model_parameters, gradio('interface_state'), None).then(
load_model_wrapper, [shared.gradio[k] for k in ['model_menu', 'loader', 'autoload_model']], shared.gradio['model_status'], show_progress=False) load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False)
load.click( load.click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
update_model_parameters, shared.gradio['interface_state'], None).then( update_model_parameters, gradio('interface_state'), None).then(
partial(load_model_wrapper, autoload=True), [shared.gradio[k] for k in ['model_menu', 'loader']], shared.gradio['model_status'], show_progress=False) partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False)
unload.click( unload.click(
unload_model, None, None).then( unload_model, None, None).then(
lambda: "Model unloaded", None, shared.gradio['model_status']) lambda: "Model unloaded", None, gradio('model_status'))
reload.click( reload.click(
unload_model, None, None).then( unload_model, None, None).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
update_model_parameters, shared.gradio['interface_state'], None).then( update_model_parameters, gradio('interface_state'), None).then(
partial(load_model_wrapper, autoload=True), [shared.gradio[k] for k in ['model_menu', 'loader']], shared.gradio['model_status'], show_progress=False) partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False)
save_settings.click( save_settings.click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False) save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False)
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False)
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=True) shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu'), gradio('model_status'), show_progress=True)
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load) shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), load)
def create_chat_settings_menus(): def create_chat_settings_menus():
@ -409,7 +410,7 @@ def create_settings_menus(default_preset):
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming') shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]) shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a'))
def create_file_saving_menus(): def create_file_saving_menus():
@ -448,39 +449,80 @@ def create_file_saving_menus():
def create_file_saving_event_handlers(): def create_file_saving_event_handlers():
shared.gradio['save_confirm'].click( shared.gradio['save_confirm'].click(
lambda x, y, z: utils.save_file(x + y, z), [shared.gradio[k] for k in ['save_root', 'save_filename', 'save_contents']], None).then( lambda x, y, z: utils.save_file(x + y, z), gradio('save_root', 'save_filename', 'save_contents'), None).then(
lambda: gr.update(visible=False), None, shared.gradio['file_saver']) lambda: gr.update(visible=False), None, gradio('file_saver'))
shared.gradio['delete_confirm'].click( shared.gradio['delete_confirm'].click(
lambda x, y: utils.delete_file(x + y), [shared.gradio[k] for k in ['delete_root', 'delete_filename']], None).then( lambda x, y: utils.delete_file(x + y), gradio('delete_root', 'delete_filename'), None).then(
lambda: gr.update(visible=False), None, shared.gradio['file_deleter']) lambda: gr.update(visible=False), None, gradio('file_deleter'))
shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, shared.gradio['file_deleter']) shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_deleter'))
shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, shared.gradio['file_saver']) shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_saver'))
if shared.is_chat(): if shared.is_chat():
shared.gradio['save_character_confirm'].click( shared.gradio['save_character_confirm'].click(
chat.save_character, [shared.gradio[k] for k in ['name2', 'greeting', 'context', 'character_picture', 'save_character_filename']], None).then( chat.save_character, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), None).then(
lambda: gr.update(visible=False), None, shared.gradio['character_saver']) lambda: gr.update(visible=False), None, gradio('character_saver'))
shared.gradio['delete_character_confirm'].click( shared.gradio['delete_character_confirm'].click(
chat.delete_character, shared.gradio['character_menu'], None).then( chat.delete_character, gradio('character_menu'), None).then(
lambda: gr.update(visible=False), None, shared.gradio['character_deleter']).then( lambda: gr.update(visible=False), None, gradio('character_deleter')).then(
lambda: gr.update(choices=utils.get_available_characters()), outputs=shared.gradio['character_menu']) lambda: gr.update(choices=utils.get_available_characters()), None, gradio('character_menu'))
shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, shared.gradio['character_saver']) shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver'))
shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, shared.gradio['character_deleter']) shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter'))
shared.gradio['save_preset'].click( shared.gradio['save_preset'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
presets.generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then( presets.generate_preset_yaml, gradio('interface_state'), gradio('save_contents')).then(
lambda: 'presets/', None, shared.gradio['save_root']).then( lambda: 'presets/', None, gradio('save_root')).then(
lambda: 'My Preset.yaml', None, shared.gradio['save_filename']).then( lambda: 'My Preset.yaml', None, gradio('save_filename')).then(
lambda: gr.update(visible=True), None, shared.gradio['file_saver']) lambda: gr.update(visible=True), None, gradio('file_saver'))
shared.gradio['delete_preset'].click( shared.gradio['delete_preset'].click(
lambda x: f'{x}.yaml', shared.gradio['preset_menu'], shared.gradio['delete_filename']).then( lambda x: f'{x}.yaml', gradio('preset_menu'), gradio('delete_filename')).then(
lambda: 'presets/', None, shared.gradio['delete_root']).then( lambda: 'presets/', None, gradio('delete_root')).then(
lambda: gr.update(visible=True), None, shared.gradio['file_deleter']) lambda: gr.update(visible=True), None, gradio('file_deleter'))
if not shared.args.multi_user:
def load_session(session, state):
with open(Path(f'logs/{session}.json'), 'r') as f:
state.update(json.loads(f.read()))
if shared.is_chat():
chat.save_persistent_history(state['history'], state['character_menu'], state['mode'])
return state
if shared.is_chat():
shared.gradio['save_session'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda x: json.dumps(x, indent=4), gradio('interface_state'), gradio('save_contents')).then(
lambda: 'logs/', None, gradio('save_root')).then(
lambda x: f'session_{shared.get_mode()}_{x + "_" if x not in ["None", None, ""] else ""}{utils.current_time()}.json', gradio('character_menu'), gradio('save_filename')).then(
lambda: gr.update(visible=True), None, gradio('file_saver'))
shared.gradio['session_menu'].change(
load_session, gradio('session_menu', 'interface_state'), gradio('interface_state')).then(
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
chat.redraw_html, shared.reload_inputs, gradio('display'))
else:
shared.gradio['save_session'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda x: json.dumps(x, indent=4), gradio('interface_state'), gradio('save_contents')).then(
lambda: 'logs/', None, gradio('save_root')).then(
lambda: f'session_{shared.get_mode()}_{utils.current_time()}.json', None, gradio('save_filename')).then(
lambda: gr.update(visible=True), None, gradio('file_saver'))
shared.gradio['session_menu'].change(
load_session, gradio('session_menu', 'interface_state'), gradio('interface_state')).then(
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False)
shared.gradio['delete_session'].click(
lambda x: f'{x}.json', gradio('session_menu'), gradio('delete_filename')).then(
lambda: 'logs/', None, gradio('delete_root')).then(
lambda: gr.update(visible=True), None, gradio('file_deleter'))
def set_interface_arguments(interface_mode, extensions, bool_active): def set_interface_arguments(interface_mode, extensions, bool_active):
@ -544,13 +586,17 @@ def create_interface():
# Create chat mode interface # Create chat mode interface
if shared.is_chat(): if shared.is_chat():
shared.input_elements = ui.list_interface_input_elements(chat=True) shared.input_elements = ui.list_interface_input_elements()
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
shared.gradio['Chat input'] = gr.State() shared.gradio.update({
shared.gradio['dummy'] = gr.State() 'interface_state': gr.State({k: None for k in shared.input_elements}),
'Chat input': gr.State(),
'dummy': gr.State(),
'history': gr.State({'internal': [], 'visible': []}),
})
with gr.Tab('Text generation', elem_id='main'): with gr.Tab('Text generation', elem_id='main'):
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'chat', 'cai-chat')) shared.gradio['display'] = gr.HTML(value=chat_html_wrapper({'internal': [], 'visible': []}, shared.settings['name1'], shared.settings['name2'], 'chat', 'cai-chat'))
shared.gradio['textbox'] = gr.Textbox(label='Input') shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row(): with gr.Row():
shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop') shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
@ -586,7 +632,7 @@ def create_interface():
with gr.Row(): with gr.Row():
with gr.Column(scale=8): with gr.Column(scale=8):
with gr.Row(): with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.', elem_classes='slim-dropdown') shared.gradio['character_menu'] = gr.Dropdown(value='None', choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button') ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button')
shared.gradio['save_character'] = gr.Button('💾', elem_classes='refresh-button') shared.gradio['save_character'] = gr.Button('💾', elem_classes='refresh-button')
shared.gradio['delete_character'] = gr.Button('🗑️', elem_classes='refresh-button') shared.gradio['delete_character'] = gr.Button('🗑️', elem_classes='refresh-button')
@ -648,7 +694,7 @@ def create_interface():
# Create notebook mode interface # Create notebook mode interface
elif shared.args.notebook: elif shared.args.notebook:
shared.input_elements = ui.list_interface_input_elements(chat=False) shared.input_elements = ui.list_interface_input_elements()
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
shared.gradio['last_input'] = gr.State('') shared.gradio['last_input'] = gr.State('')
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
@ -687,7 +733,7 @@ def create_interface():
# Create default mode interface # Create default mode interface
else: else:
shared.input_elements = ui.list_interface_input_elements(chat=False) shared.input_elements = ui.list_interface_input_elements()
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
shared.gradio['last_input'] = gr.State('') shared.gradio['last_input'] = gr.State('')
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
@ -731,8 +777,8 @@ def create_interface():
with gr.Tab("Training", elem_id="training-tab"): with gr.Tab("Training", elem_id="training-tab"):
training.create_train_interface() training.create_train_interface()
# Interface mode tab # Session tab
with gr.Tab("Interface mode", elem_id="interface-mode"): with gr.Tab("Session", elem_id="session-tab"):
modes = ["default", "notebook", "chat"] modes = ["default", "notebook", "chat"]
current_mode = "default" current_mode = "default"
for mode in modes[1:]: for mode in modes[1:]:
@ -745,9 +791,12 @@ def create_interface():
bool_active = [k for k in bool_list if vars(shared.args)[k]] bool_active = [k for k in bool_list if vars(shared.args)[k]]
with gr.Row(): with gr.Row():
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", elem_classes="small-button") with gr.Column():
shared.gradio['toggle_dark_mode'] = gr.Button('Toggle dark/light mode', elem_classes="small-button") with gr.Row():
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode", elem_classes='slim-dropdown')
shared.gradio['reset_interface'] = gr.Button("Apply and restart", elem_classes="small-button", variant="primary")
shared.gradio['toggle_dark_mode'] = gr.Button('Toggle 💡', elem_classes="small-button")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -756,212 +805,239 @@ def create_interface():
with gr.Column(): with gr.Column():
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags", elem_classes='checkboxgroup-table') shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags", elem_classes='checkboxgroup-table')
with gr.Column():
if not shared.args.multi_user:
with gr.Row(): with gr.Row():
extension_name = gr.Textbox(lines=1, label='Install or update an extension', info='Enter the GitHub URL below. For a list of extensions, see: https://github.com/oobabooga/text-generation-webui-extensions ⚠️ WARNING ⚠️ : extensions can execute arbitrary code. Make sure to inspect their source code before activating them.') shared.gradio['session_menu'] = gr.Dropdown(choices=utils.get_available_sessions(), value='None', label='Session', elem_classes='slim-dropdown', info='When saving a session, make sure to keep the initial part of the filename (session_chat, session_notebook, or session_default), otherwise it will not appear on this list afterwards.')
extension_install = gr.Button('Install or update', elem_classes="small-button") ui.create_refresh_button(shared.gradio['session_menu'], lambda: None, lambda: {'choices': utils.get_available_sessions()}, ['refresh-button'])
shared.gradio['save_session'] = gr.Button('💾', elem_classes=['refresh-button'])
shared.gradio['delete_session'] = gr.Button('🗑️', elem_classes=['refresh-button'])
extension_name = gr.Textbox(lines=1, label='Install or update an extension', info='Enter the GitHub URL below and press Enter. For a list of extensions, see: https://github.com/oobabooga/text-generation-webui-extensions ⚠️ WARNING ⚠️ : extensions can execute arbitrary code. Make sure to inspect their source code before activating them.')
extension_status = gr.Markdown() extension_status = gr.Markdown()
extension_install.click( extension_name.submit(
clone_or_pull_repository, extension_name, extension_status, show_progress=False).then( clone_or_pull_repository, extension_name, extension_status, show_progress=False).then(
lambda: gr.update(choices=utils.get_available_extensions(), value=shared.args.extensions), outputs=shared.gradio['extensions_menu']) lambda: gr.update(choices=utils.get_available_extensions(), value=shared.args.extensions), None, gradio('extensions_menu'))
# Reset interface event # Reset interface event
shared.gradio['reset_interface'].click( shared.gradio['reset_interface'].click(
set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then( set_interface_arguments, gradio('interface_modes_menu', 'extensions_menu', 'bool_menu'), None).then(
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}') lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
shared.gradio['toggle_dark_mode'].click(lambda: None, None, None, _js='() => {document.getElementsByTagName("body")[0].classList.toggle("dark")}') shared.gradio['toggle_dark_mode'].click(lambda: None, None, None, _js='() => {document.getElementsByTagName("body")[0].classList.toggle("dark")}')
# chat mode event handlers # chat mode event handlers
if shared.is_chat(): if shared.is_chat():
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'start_with', 'interface_state']] shared.input_params = gradio('Chat input', 'start_with', 'interface_state')
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] clear_arr = gradio('Clear history-confirm', 'Clear history', 'Clear history-cancel')
shared.reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode', 'chat_style']] shared.reload_inputs = gradio('history', 'name1', 'name2', 'mode', 'chat_style')
gen_events.append(shared.gradio['Generate'].click( gen_events.append(shared.gradio['Generate'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then( lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then(
chat.generate_chat_reply_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then( chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
) )
gen_events.append(shared.gradio['textbox'].submit( gen_events.append(shared.gradio['textbox'].submit(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then( lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then(
chat.generate_chat_reply_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then( chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
) )
gen_events.append(shared.gradio['Regenerate'].click( gen_events.append(shared.gradio['Regenerate'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
partial(chat.generate_chat_reply_wrapper, regenerate=True), shared.input_params, shared.gradio['display'], show_progress=False).then( partial(chat.generate_chat_reply_wrapper, regenerate=True), shared.input_params, gradio('display', 'history'), show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
) )
gen_events.append(shared.gradio['Continue'].click( gen_events.append(shared.gradio['Continue'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
partial(chat.generate_chat_reply_wrapper, _continue=True), shared.input_params, shared.gradio['display'], show_progress=False).then( partial(chat.generate_chat_reply_wrapper, _continue=True), shared.input_params, gradio('display', 'history'), show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
) )
gen_events.append(shared.gradio['Impersonate'].click( gen_events.append(shared.gradio['Impersonate'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda x: x, shared.gradio['textbox'], shared.gradio['Chat input'], show_progress=False).then( lambda x: x, gradio('textbox'), gradio('Chat input'), show_progress=False).then(
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False).then( chat.impersonate_wrapper, shared.input_params, gradio('textbox'), show_progress=False).then(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
) )
shared.gradio['Replace last reply'].click( shared.gradio['Replace last reply'].click(
chat.replace_last_reply, shared.gradio['textbox'], None).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: '', None, shared.gradio['textbox'], show_progress=False).then( chat.replace_last_reply, gradio('textbox', 'interface_state'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( lambda: '', None, gradio('textbox'), show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
shared.gradio['Send dummy message'].click( shared.gradio['Send dummy message'].click(
chat.send_dummy_message, shared.gradio['textbox'], None).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: '', None, shared.gradio['textbox'], show_progress=False).then( chat.send_dummy_message, gradio('textbox', 'interface_state'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( lambda: '', None, gradio('textbox'), show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
shared.gradio['Send dummy reply'].click( shared.gradio['Send dummy reply'].click(
chat.send_dummy_reply, shared.gradio['textbox'], None).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: '', None, shared.gradio['textbox'], show_progress=False).then( chat.send_dummy_reply, gradio('textbox', 'interface_state'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( lambda: '', None, gradio('textbox'), show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Clear history-confirm'].click( shared.gradio['Clear history-confirm'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
chat.clear_chat_log, [shared.gradio[k] for k in ['greeting', 'mode']], None).then( chat.clear_chat_log, gradio('interface_state'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( chat.redraw_html, shared.reload_inputs, gradio('display')).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
shared.gradio['Remove last'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.remove_last_message, gradio('history'), gradio('textbox', 'history'), show_progress=False).then(
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
shared.gradio['character_menu'].change(
partial(chat.load_character, instruct=False), gradio('character_menu', 'name1', 'name2'), gradio('name1', 'name2', 'character_picture', 'greeting', 'context', 'dummy')).then(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.load_persistent_history, gradio('interface_state'), gradio('history')).then(
chat.redraw_html, shared.reload_inputs, gradio('display'))
shared.gradio['Stop'].click( shared.gradio['Stop'].click(
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then( stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
shared.gradio['mode'].change( shared.gradio['mode'].change(
lambda x: gr.update(visible=x != 'instruct'), shared.gradio['mode'], shared.gradio['chat_style'], show_progress=False).then( lambda x: gr.update(visible=x != 'instruct'), gradio('mode'), gradio('chat_style'), show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
shared.gradio['chat_style'].change(chat.redraw_html, shared.reload_inputs, shared.gradio['display']) shared.gradio['chat_style'].change(chat.redraw_html, shared.reload_inputs, gradio('display'))
shared.gradio['instruction_template'].change( shared.gradio['instruction_template'].change(
partial(chat.load_character, instruct=True), [shared.gradio[k] for k in ['instruction_template', 'name1_instruct', 'name2_instruct']], [shared.gradio[k] for k in ['name1_instruct', 'name2_instruct', 'dummy', 'dummy', 'context_instruct', 'turn_template']]) partial(chat.load_character, instruct=True), gradio('instruction_template', 'name1_instruct', 'name2_instruct'), gradio('name1_instruct', 'name2_instruct', 'dummy', 'dummy', 'context_instruct', 'turn_template'))
shared.gradio['upload_chat_history'].upload( shared.gradio['upload_chat_history'].upload(
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then( chat.load_history, gradio('upload_chat_history', 'history'), gradio('history')).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=False) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, gradio('history'), gradio('textbox'), show_progress=False)
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Remove last'].click(
chat.remove_last_message, None, shared.gradio['textbox'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
# Save/delete a character # Save/delete a character
shared.gradio['save_character'].click( shared.gradio['save_character'].click(
lambda x: x, shared.gradio['name2'], shared.gradio['save_character_filename']).then( lambda x: x, gradio('name2'), gradio('save_character_filename')).then(
lambda: gr.update(visible=True), None, shared.gradio['character_saver']) lambda: gr.update(visible=True), None, gradio('character_saver'))
shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, shared.gradio['character_deleter']) shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter'))
shared.gradio['save_template'].click( shared.gradio['save_template'].click(
lambda: 'My Template.yaml', None, shared.gradio['save_filename']).then( lambda: 'My Template.yaml', None, gradio('save_filename')).then(
lambda: 'characters/instruction-following/', None, shared.gradio['save_root']).then( lambda: 'characters/instruction-following/', None, gradio('save_root')).then(
chat.generate_instruction_template_yaml, [shared.gradio[k] for k in ['name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template']], shared.gradio['save_contents']).then( chat.generate_instruction_template_yaml, gradio('name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template'), gradio('save_contents')).then(
lambda: gr.update(visible=True), None, shared.gradio['file_saver']) lambda: gr.update(visible=True), None, gradio('file_saver'))
shared.gradio['delete_template'].click( shared.gradio['delete_template'].click(
lambda x: f'{x}.yaml', shared.gradio['instruction_template'], shared.gradio['delete_filename']).then( lambda x: f'{x}.yaml', gradio('instruction_template'), gradio('delete_filename')).then(
lambda: 'characters/instruction-following/', None, shared.gradio['delete_root']).then( lambda: 'characters/instruction-following/', None, gradio('delete_root')).then(
lambda: gr.update(visible=True), None, shared.gradio['file_deleter']) lambda: gr.update(visible=True), None, gradio('file_deleter'))
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True, user_request=True), shared.gradio['mode'], shared.gradio['download']) shared.gradio['download_button'].click(chat.save_history, gradio('history'), gradio('download'))
shared.gradio['Submit character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) shared.gradio['Submit character'].click(chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu'))
shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, [shared.gradio['Submit character']]) shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character'))
shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, [shared.gradio['Submit character']]) shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character'))
shared.gradio['character_menu'].change( shared.gradio['Submit tavern character'].click(chat.upload_tavern_character, gradio('upload_img_tavern', 'tavern_json'), gradio('character_menu'))
partial(chat.load_character, instruct=False), [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'dummy']]).then( shared.gradio['upload_img_tavern'].upload(chat.check_tavern_character, gradio('upload_img_tavern'), gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False)
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) shared.gradio['upload_img_tavern'].clear(lambda: (None, None, None, gr.update(interactive=False)), None, gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False)
shared.gradio['Submit tavern character'].click(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['tavern_json']], [shared.gradio['character_menu']])
shared.gradio['upload_img_tavern'].upload(chat.check_tavern_character, shared.gradio['upload_img_tavern'], [shared.gradio[k] for k in ['tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character']], show_progress=False)
shared.gradio['upload_img_tavern'].clear(lambda: (None, None, None, gr.update(interactive=False)), None, [shared.gradio[k] for k in ['tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character']], show_progress=False)
shared.gradio['your_picture'].change( shared.gradio['your_picture'].change(
chat.upload_your_profile_picture, shared.gradio['your_picture'], None).then( chat.upload_your_profile_picture, gradio('your_picture'), None).then(
partial(chat.redraw_html, reset_cache=True), shared.reload_inputs, shared.gradio['display']) partial(chat.redraw_html, reset_cache=True), shared.reload_inputs, gradio('display'))
# notebook/default modes event handlers # notebook/default modes event handlers
else: else:
shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']] shared.input_params = gradio('textbox', 'interface_state')
if shared.args.notebook: if shared.args.notebook:
output_params = [shared.gradio[k] for k in ['textbox', 'html']] output_params = gradio('textbox', 'html')
else: else:
output_params = [shared.gradio[k] for k in ['output_textbox', 'html']] output_params = gradio('output_textbox', 'html')
gen_events.append(shared.gradio['Generate'].click( gen_events.append(shared.gradio['Generate'].click(
lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then( lambda x: x, gradio('textbox'), gradio('last_input')).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then( generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
) )
gen_events.append(shared.gradio['textbox'].submit( gen_events.append(shared.gradio['textbox'].submit(
lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then( lambda x: x, gradio('textbox'), gradio('last_input')).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then( generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
) )
if shared.args.notebook: if shared.args.notebook:
shared.gradio['Undo'].click(lambda x: x, shared.gradio['last_input'], shared.gradio['textbox'], show_progress=False) shared.gradio['Undo'].click(lambda x: x, gradio('last_input'), gradio('textbox'), show_progress=False)
shared.gradio['markdown_render'].click(lambda x: x, shared.gradio['textbox'], shared.gradio['markdown'], queue=False) shared.gradio['markdown_render'].click(lambda x: x, gradio('textbox'), gradio('markdown'), queue=False)
gen_events.append(shared.gradio['Regenerate'].click( gen_events.append(shared.gradio['Regenerate'].click(
lambda x: x, shared.gradio['last_input'], shared.gradio['textbox'], show_progress=False).then( lambda x: x, gradio('last_input'), gradio('textbox'), show_progress=False).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then( generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
) )
else: else:
shared.gradio['markdown_render'].click(lambda x: x, shared.gradio['output_textbox'], shared.gradio['markdown'], queue=False) shared.gradio['markdown_render'].click(lambda x: x, gradio('output_textbox'), gradio('markdown'), queue=False)
gen_events.append(shared.gradio['Continue'].click( gen_events.append(shared.gradio['Continue'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
generate_reply_wrapper, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False).then( generate_reply_wrapper, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False).then(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}") lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
) )
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['prompt_menu'].change(load_prompt, shared.gradio['prompt_menu'], shared.gradio['textbox'], show_progress=False) shared.gradio['prompt_menu'].change(load_prompt, gradio('prompt_menu'), gradio('textbox'), show_progress=False)
shared.gradio['save_prompt'].click( shared.gradio['save_prompt'].click(
lambda x: x, shared.gradio['textbox'], shared.gradio['save_contents']).then( lambda x: x, gradio('textbox'), gradio('save_contents')).then(
lambda: 'prompts/', None, shared.gradio['save_root']).then( lambda: 'prompts/', None, gradio('save_root')).then(
lambda: utils.current_time() + '.txt', None, shared.gradio['save_filename']).then( lambda: utils.current_time() + '.txt', None, gradio('save_filename')).then(
lambda: gr.update(visible=True), None, shared.gradio['file_saver']) lambda: gr.update(visible=True), None, gradio('file_saver'))
shared.gradio['delete_prompt'].click( shared.gradio['delete_prompt'].click(
lambda: 'prompts/', None, shared.gradio['delete_root']).then( lambda: 'prompts/', None, gradio('delete_root')).then(
lambda x: x + '.txt', shared.gradio['prompt_menu'], shared.gradio['delete_filename']).then( lambda x: x + '.txt', gradio('prompt_menu'), gradio('delete_filename')).then(
lambda: gr.update(visible=True), None, shared.gradio['file_deleter']) lambda: gr.update(visible=True), None, gradio('file_deleter'))
shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False) shared.gradio['count_tokens'].click(count_tokens, gradio('textbox'), gradio('status'), show_progress=False)
create_file_saving_event_handlers() create_file_saving_event_handlers()
shared.gradio['interface'].load(lambda: None, None, None, _js=f"() => {{{js}}}")
shared.gradio['interface'].load(
lambda: None, None, None, _js=f"() => {{{js}}}").then(
partial(ui.apply_interface_values, {}, use_persistent=True), None, gradio(ui.list_interface_input_elements()), show_progress=False)
if shared.settings['dark_theme']: if shared.settings['dark_theme']:
shared.gradio['interface'].load(lambda: None, None, None, _js="() => document.getElementsByTagName('body')[0].classList.add('dark')") shared.gradio['interface'].load(lambda: None, None, None, _js="() => document.getElementsByTagName('body')[0].classList.add('dark')")
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False) if shared.is_chat():
shared.gradio['interface'].load(chat.redraw_html, shared.reload_inputs, gradio('display'))
# Extensions tabs # Extensions tabs
extensions_module.create_extensions_tabs() extensions_module.create_extensions_tabs()
@ -1058,7 +1134,11 @@ if __name__ == "__main__":
if shared.args.lora: if shared.args.lora:
add_lora_to_model(shared.args.lora) add_lora_to_model(shared.args.lora)
# Force a character to be loaded # Forcing some events to be triggered on page load
shared.persistent_interface_state.update({
'loader': shared.args.loader or 'Transformers',
})
if shared.is_chat(): if shared.is_chat():
shared.persistent_interface_state.update({ shared.persistent_interface_state.update({
'mode': shared.settings['mode'], 'mode': shared.settings['mode'],
@ -1066,11 +1146,11 @@ if __name__ == "__main__":
'instruction_template': shared.settings['instruction_template'] 'instruction_template': shared.settings['instruction_template']
}) })
shared.persistent_interface_state.update({ if Path("cache/pfp_character.png").exists():
'loader': shared.args.loader or 'Transformers', Path("cache/pfp_character.png").unlink()
})
shared.generation_lock = Lock() shared.generation_lock = Lock()
# Launch the web UI # Launch the web UI
create_interface() create_interface()
while True: while True: