Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-22 08:07:56 +01:00)
Fix lag in the chat tab during streaming
This commit is contained in:
parent 736fe4aa3e
commit 8513028968
@@ -205,7 +205,7 @@ def get_stopping_strings(state):
|
|||||||
return list(set(stopping_strings))
|
return list(set(stopping_strings))
|
||||||
|
|
||||||
|
|
||||||
def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True):
|
def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
|
||||||
history = state['history']
|
history = state['history']
|
||||||
output = copy.deepcopy(history)
|
output = copy.deepcopy(history)
|
||||||
output = apply_extensions('history', output)
|
output = apply_extensions('history', output)
|
||||||
@@ -256,7 +256,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
reply = None
|
reply = None
|
||||||
for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)):
|
for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True, for_ui=for_ui)):
|
||||||
|
|
||||||
# Extract the reply
|
# Extract the reply
|
||||||
visible_reply = reply
|
visible_reply = reply
|
||||||
@@ -311,7 +311,7 @@ def impersonate_wrapper(text, state):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True):
|
def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
|
||||||
history = state['history']
|
history = state['history']
|
||||||
if regenerate or _continue:
|
if regenerate or _continue:
|
||||||
text = ''
|
text = ''
|
||||||
@@ -319,7 +319,7 @@ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_
|
|||||||
yield history
|
yield history
|
||||||
return
|
return
|
||||||
|
|
||||||
for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
|
for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui):
|
||||||
yield history
|
yield history
|
||||||
|
|
||||||
|
|
||||||
@@ -351,7 +351,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||||||
send_dummy_message(text, state)
|
send_dummy_message(text, state)
|
||||||
send_dummy_reply(state['start_with'], state)
|
send_dummy_reply(state['start_with'], state)
|
||||||
|
|
||||||
for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True)):
|
for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
|
||||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
|
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
|
||||||
|
|
||||||
|
|
||||||
|
@@ -33,7 +33,7 @@ def generate_reply(*args, **kwargs):
|
|||||||
shared.generation_lock.release()
|
shared.generation_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False):
|
def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
|
||||||
|
|
||||||
# Find the appropriate generation function
|
# Find the appropriate generation function
|
||||||
generate_func = apply_extensions('custom_generate_reply')
|
generate_func = apply_extensions('custom_generate_reply')
|
||||||
@@ -96,7 +96,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||||||
# Limit updates to 24 or 5 per second to avoid lag in the Gradio UI
|
# Limit updates to 24 or 5 per second to avoid lag in the Gradio UI
|
||||||
# API updates are not limited
|
# API updates are not limited
|
||||||
else:
|
else:
|
||||||
min_update_interval = 0 if not escape_html else 0.2 if (shared.args.listen or shared.args.share) else 0.0417
|
min_update_interval = 0 if not for_ui else 0.2 if (shared.args.listen or shared.args.share) else 0.0417
|
||||||
if cur_time - last_update > min_update_interval:
|
if cur_time - last_update > min_update_interval:
|
||||||
last_update = cur_time
|
last_update = cur_time
|
||||||
yield reply
|
yield reply
|
||||||
@@ -178,7 +178,7 @@ def generate_reply_wrapper(question, state, stopping_strings=None):
|
|||||||
reply = question if not shared.is_seq2seq else ''
|
reply = question if not shared.is_seq2seq else ''
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
for reply in generate_reply(question, state, stopping_strings, is_chat=False, escape_html=True):
|
for reply in generate_reply(question, state, stopping_strings, is_chat=False, escape_html=True, for_ui=True):
|
||||||
if not shared.is_seq2seq:
|
if not shared.is_seq2seq:
|
||||||
reply = question + reply
|
reply = question + reply
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user