Fix loading session in chat mode

This commit is contained in:
oobabooga 2023-08-02 21:13:16 -07:00
parent 4b6c1d3f08
commit 32c564509e
3 changed files with 22 additions and 8 deletions

View File

@ -418,6 +418,10 @@ def save_persistent_history(history, character, mode):
def load_persistent_history(state):
if shared.session_is_loading:
shared.session_is_loading = False
return state['history']
if state['mode'] == 'instruct':
return state['history']

View File

@ -30,6 +30,10 @@ reload_inputs = [] # Parameters for reloading the chat interface
# For restarting the interface
need_restart = False
# To prevent the persistent chat history from being loaded when
# a session JSON file is being loaded in chat mode
session_is_loading = False
settings = {
'dark_theme': True,
'autoload_model': False,

View File

@ -511,11 +511,11 @@ def create_file_saving_event_handlers():
def load_session(file, state):
decoded_file = file if type(file) == str else file.decode('utf-8')
data = json.loads(decoded_file)
if shared.is_chat() and 'character_menu' in data and state.get('character_menu') != data.get('character_menu'):
shared.session_is_loading = True
state.update(data)
if shared.is_chat():
chat.save_persistent_history(state['history'], state['character_menu'], state['mode'])
return state
shared.gradio['save_session'].click(
@ -523,6 +523,12 @@ def create_file_saving_event_handlers():
lambda x: json.dumps(x, indent=4), gradio('interface_state'), gradio('temporary_text')).then(
None, gradio('temporary_text'), None, _js=f"(contents) => {{{ui.save_files_js}; saveSession(contents, \"{shared.get_mode()}\")}}")
if shared.is_chat():
shared.gradio['load_session'].upload(
load_session, gradio('load_session', 'interface_state'), gradio('interface_state')).then(
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
chat.redraw_html, shared.reload_inputs, gradio('display'))
else:
shared.gradio['load_session'].upload(
load_session, gradio('load_session', 'interface_state'), gradio('interface_state')).then(
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False)