Fix lag in the chat tab during streaming

oobabooga 2023-12-12 13:00:38 -08:00
parent 736fe4aa3e
commit 8513028968
2 changed files with 8 additions and 8 deletions

modules/chat.py

@@ -205,7 +205,7 @@ def get_stopping_strings(state):
     return list(set(stopping_strings))
 
 
-def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True):
+def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
     history = state['history']
     output = copy.deepcopy(history)
     output = apply_extensions('history', output)
@@ -256,7 +256,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
 
     # Generate
     reply = None
-    for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)):
+    for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True, for_ui=for_ui)):
 
         # Extract the reply
         visible_reply = reply
@@ -311,7 +311,7 @@ def impersonate_wrapper(text, state):
         return
 
 
-def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True):
+def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
     history = state['history']
     if regenerate or _continue:
         text = ''
@@ -319,7 +319,7 @@ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_
             yield history
             return
 
-    for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
+    for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui):
         yield history
@@ -351,7 +351,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
         send_dummy_message(text, state)
         send_dummy_reply(state['start_with'], state)
 
-    for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True)):
+    for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
         yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
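
The chat-side change is pure plumbing: the new for_ui keyword is threaded through the generator chain, with the Gradio-facing wrapper setting it to True while API callers that enter lower in the chain keep the default False. A minimal sketch of the pattern (self-contained; the function names below are illustrative, not the repo's):

    import time

    def inner_stream(prompt, for_ui=False):
        # Innermost generator: the only layer that inspects the flag.
        last_update = 0.0
        for i in range(5):
            reply = f"{prompt} token{i}"
            min_update_interval = 0.0417 if for_ui else 0  # ~24 updates/s for the UI
            cur_time = time.time()
            if cur_time - last_update > min_update_interval:
                last_update = cur_time
                yield reply

    def middle_layer(prompt, for_ui=False):
        # Intermediate wrappers just forward the flag unchanged.
        yield from inner_stream(prompt, for_ui=for_ui)

    def ui_entry(prompt):
        # Gradio event handler: marks the stream as UI-bound.
        yield from middle_layer(prompt, for_ui=True)

    def api_entry(prompt):
        # API path: leaves the default False, so nothing is throttled.
        yield from middle_layer(prompt)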

modules/text_generation.py

@@ -33,7 +33,7 @@ def generate_reply(*args, **kwargs):
         shared.generation_lock.release()
 
 
-def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False):
+def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
 
     # Find the appropriate generation function
     generate_func = apply_extensions('custom_generate_reply')
@@ -96,7 +96,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
                 # Limit updates to 24 or 5 per second to avoid lag in the Gradio UI
                 # API updates are not limited
                 else:
-                    min_update_interval = 0 if not escape_html else 0.2 if (shared.args.listen or shared.args.share) else 0.0417
+                    min_update_interval = 0 if not for_ui else 0.2 if (shared.args.listen or shared.args.share) else 0.0417
                     if cur_time - last_update > min_update_interval:
                         last_update = cur_time
                         yield reply
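
This hunk is the actual fix. The update limiter was previously keyed on escape_html, which the chat path never sets, so the chat tab pushed every single token update to Gradio and lagged; the explicit for_ui flag now throttles any UI-bound stream regardless of escaping. The constants match the comment above: a 0.2 s interval is 5 updates/s when the UI is exposed via --listen or --share, and 1/24 ≈ 0.0417 s is 24 updates/s locally. The gate itself is a simple wall-clock check, sketched standalone below (constants copied from the diff; the wrapper function and its remote parameter are assumptions for illustration):

    import time

    def throttle(token_stream, for_ui=False, remote=False):
        # Drop intermediate updates arriving faster than the UI can render.
        min_update_interval = 0 if not for_ui else (0.2 if remote else 0.0417)
        last_update = 0.0
        reply = None
        for reply in token_stream:
            cur_time = time.time()
            if cur_time - last_update > min_update_interval:
                last_update = cur_time
                yield reply

        # Always emit the final, complete reply.
        if reply is not None:
            yield reply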
@@ -178,7 +178,7 @@ def generate_reply_wrapper(question, state, stopping_strings=None):
     reply = question if not shared.is_seq2seq else ''
     yield formatted_outputs(reply, shared.model_name)
 
-    for reply in generate_reply(question, state, stopping_strings, is_chat=False, escape_html=True):
+    for reply in generate_reply(question, state, stopping_strings, is_chat=False, escape_html=True, for_ui=True):
         if not shared.is_seq2seq:
             reply = question + reply
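
After this change the caller alone decides whether a stream is throttled: both UI wrappers pass for_ui=True, while API code that calls generate_reply directly keeps the default and receives every token. A rough way to see the effect on update volume (a self-contained simulation, not code from the repo; sleep granularity varies by OS, so the throttled counts are approximate):

    import time

    def updates_through_gate(n_tokens, min_update_interval, token_time=0.001):
        # Simulate n_tokens arriving every token_time seconds and count
        # how many updates survive the wall-clock gate.
        last_update = 0.0
        sent = 0
        for _ in range(n_tokens):
            time.sleep(token_time)
            cur_time = time.time()
            if cur_time - last_update > min_update_interval:
                last_update = cur_time
                sent += 1
        return sent

    # 500 tokens at ~1000 tokens/s (about 0.5 s of streaming):
    print(updates_through_gate(500, 0))       # API: all 500 updates
    print(updates_through_gate(500, 0.0417))  # local UI: roughly a dozen
    print(updates_through_gate(500, 0.2))     # remote UI: two or three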