From 70845c76fb7a7db4de8c14ba9e9aaa8e3ef51955 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 26 Apr 2024 10:14:51 -0300
Subject: [PATCH] Add back the max_updates_second parameter (#5937)

---
 modules/shared.py          |  1 +
 modules/text_generation.py | 11 +++++++++++
 modules/ui.py              |  1 +
 modules/ui_parameters.py   |  1 +
 settings-template.yaml     |  1 +
 5 files changed, 15 insertions(+)

diff --git a/modules/shared.py b/modules/shared.py
index 0c0db858..53f2e52b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -46,6 +46,7 @@ settings = {
     'truncation_length_min': 0,
     'truncation_length_max': 200000,
     'max_tokens_second': 0,
+    'max_updates_second': 0,
     'prompt_lookup_num_tokens': 0,
     'custom_stopping_strings': '',
     'custom_token_bans': '',
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 8b422399..17ec6bd8 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -84,6 +84,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
         state = copy.deepcopy(state)
         state['stream'] = True
 
+    min_update_interval = 0
+    if state.get('max_updates_second', 0) > 0:
+        min_update_interval = 1 / state['max_updates_second']
+
     # Generate
     for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
         reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
@@ -101,7 +105,14 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
 
                 last_update = time.time()
                 yield reply
+
+            # Limit updates to avoid lag in the Gradio UI
+            # API updates are not limited
             else:
+                if cur_time - last_update > min_update_interval:
+                    last_update = cur_time
+                    yield reply
+
                 yield reply
 
         if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
diff --git a/modules/ui.py b/modules/ui.py
index fc507af7..f9a04522 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -122,6 +122,7 @@ def list_interface_input_elements():
         'max_new_tokens',
         'auto_max_new_tokens',
         'max_tokens_second',
+        'max_updates_second',
         'prompt_lookup_num_tokens',
         'seed',
         'temperature',
diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py
index b3adc570..0414cdde 100644
--- a/modules/ui_parameters.py
+++ b/modules/ui_parameters.py
@@ -85,6 +85,7 @@ def create_ui(default_preset):
             shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
             shared.gradio['prompt_lookup_num_tokens'] = gr.Slider(value=shared.settings['prompt_lookup_num_tokens'], minimum=0, maximum=10, step=1, label='prompt_lookup_num_tokens', info='Activates Prompt Lookup Decoding.')
             shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum tokens/second', info='To make text readable in real time.')
+            shared.gradio['max_updates_second'] = gr.Slider(value=shared.settings['max_updates_second'], minimum=0, maximum=24, step=1, label='Maximum UI updates/second', info='Set this if you experience lag in the UI during streaming.')
             shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
             shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
             shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')
diff --git a/settings-template.yaml b/settings-template.yaml
index a89b6282..f09c845e 100644
--- a/settings-template.yaml
+++ b/settings-template.yaml
@@ -15,6 +15,7 @@ truncation_length: 2048
 truncation_length_min: 0
 truncation_length_max: 200000
 max_tokens_second: 0
+max_updates_second: 0
 prompt_lookup_num_tokens: 0
 custom_stopping_strings: ''
 custom_token_bans: ''
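
Note on the mechanism: _generate_reply() turns the new max_updates_second setting into a minimum interval between streamed updates (min_update_interval = 1 / max_updates_second, with the default of 0 leaving streaming unthrottled), and the else branch then drops intermediate yields that arrive before that interval has elapsed. The sketch below restates the pattern as a self-contained generator; the name throttled_stream and its arguments are illustrative only and do not exist in the codebase.

    import time

    def throttled_stream(token_stream, max_updates_second=0):
        # Translate the updates/second cap into a minimum gap between yields;
        # 0 (the default) leaves the stream unthrottled, as in the patched code.
        min_update_interval = 0
        if max_updates_second > 0:
            min_update_interval = 1 / max_updates_second

        last_update = -1  # far enough in the past that the first item always passes
        reply = None
        for reply in token_stream:
            cur_time = time.time()
            if cur_time - last_update > min_update_interval:
                last_update = cur_time
                yield reply

        # Emit the final value unconditionally so the tail of the stream is
        # never swallowed by the rate limit.
        if reply is not None:
            yield reply

For instance, throttled_stream(iter(words), max_updates_second=10) forwards at most one update every 100 ms no matter how quickly items arrive. Capping yields rather than generation is the point of the parameter: each streamed update redraws the Gradio UI, which is the lag the new slider's info text addresses, while the max_tokens_second path above it slows the text itself, and the new comment notes that API updates are not limited.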
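Persisting the setting: max_updates_second defaults to 0 (no limit) in both shared.py and settings-template.yaml. Assuming the usual convention that a user settings.yaml mirrors settings-template.yaml, a line such as max_updates_second: 10 there should apply the cap across sessions; otherwise the new 'Maximum UI updates/second' slider sets it per session.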