From af65c129008c7b84a933247867673575de07ae33 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 27 Mar 2023 13:23:59 -0300
Subject: [PATCH] Change Stop button behavior

---
 modules/callbacks.py       | 2 +-
 modules/chat.py            | 4 ----
 modules/text_generation.py | 4 ++++
 server.py                  | 9 +++++----
 4 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/modules/callbacks.py b/modules/callbacks.py
index 8d30d615..d85f406d 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -54,7 +54,7 @@ class Iteratorize:
         self.stop_now = False
 
         def _callback(val):
-            if self.stop_now:
+            if self.stop_now or shared.stop_everything:
                 raise ValueError
             self.q.put(val)
 
diff --git a/modules/chat.py b/modules/chat.py
index 1a43cf3d..cc3c45c7 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -80,11 +80,7 @@ def extract_message_from_reply(reply, name1, name2, check):
     reply = fix_newlines(reply)
     return reply, next_character_found
 
-def stop_everything_event():
-    shared.stop_everything = True
-
 def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
-    shared.stop_everything = False
     just_started = True
     eos_token = '\n' if check else None
     name1_original = name1
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 9b2c233d..477257c2 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -99,9 +99,13 @@ def set_manual_seed(seed):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(seed)
 
+def stop_everything_event():
+    shared.stop_everything = True
+
 def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
     set_manual_seed(seed)
+    shared.stop_everything = False
     t0 = time.time()
 
     original_question = question
diff --git a/server.py b/server.py
index 020093ee..9f90c79b 100644
--- a/server.py
+++ b/server.py
@@ -16,7 +16,8 @@ import modules.ui as ui
 from modules.html_generator import generate_chat_html
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
-from modules.text_generation import clear_torch_cache, generate_reply
+from modules.text_generation import (clear_torch_cache, generate_reply,
+                                     stop_everything_event)
 
 # Loading custom settings
 settings_file = None
@@ -366,7 +367,7 @@ def create_interface():
             gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False)
+            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
 
             shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
             shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
@@ -437,7 +438,7 @@ def create_interface():
             output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
             gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
         else:
@@ -471,7 +472,7 @@ def create_interface():
             gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
            gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
     with gr.Tab("Interface mode", elem_id="interface-mode"):
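
How the callbacks.py and text_generation.py hunks fit together: the Stop button now
flips a module-level flag (shared.stop_everything), the streaming callback checks it
on every token and aborts by raising, and generate_reply() resets the flag at the
start of each run. Below is a minimal, self-contained sketch of that cooperative
stop. The names are illustrative stand-ins, not the webui's real module layout:
stop_everything plays the role of shared.stop_everything, and stream_tokens() plays
the role of Iteratorize plus generate_reply().

    import threading
    from queue import Queue

    stop_everything = False  # reset by the generator, set by the Stop button

    def stop_everything_event():
        """Flip the flag that the streaming callback checks on every token."""
        global stop_everything
        stop_everything = True

    def stream_tokens(n=1000):
        """A producer thread pushes tokens into a queue via a callback that
        aborts by raising once the flag is set (mirrors Iteratorize)."""
        global stop_everything
        stop_everything = False  # mirrors the reset moved into generate_reply()
        q = Queue()
        sentinel = object()

        def _callback(val):
            if stop_everything:   # mirrors Iteratorize._callback in the diff
                raise ValueError  # unwinds out of the generation loop
            q.put(val)

        def _producer():
            try:
                for i in range(n):
                    _callback(f'token-{i}')
            except ValueError:
                pass              # stop requested mid-generation
            finally:
                q.put(sentinel)   # always signal end-of-stream

        threading.Thread(target=_producer, daemon=True).start()
        while (tok := q.get()) is not sentinel:
            yield tok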
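The server.py hunks then wire the Stop button to this flag. A hedged sketch of that
wiring, building on the sketch above (a hypothetical one-textbox demo, gradio 3.x
assumed; no_stream stands in for shared.args.no_stream):

    import gradio as gr

    no_stream = False  # stands in for shared.args.no_stream

    def generate(prompt):
        text = prompt
        for tok in stream_tokens(50):  # generator from the sketch above
            text += ' ' + tok
            yield text                 # streamed partial output

    with gr.Blocks() as demo:
        textbox = gr.Textbox()
        output = gr.Textbox()
        gen_events = [textbox.submit(generate, textbox, output)]
        # Hard-cancel the event only when not streaming; when streaming, the
        # flag lets the generator break out on its own and keep partial text.
        gr.Button('Stop').click(stop_everything_event, [], [], queue=False,
                                cancels=gen_events if no_stream else None)

    demo.queue().launch()

The apparent rationale for cancels=gen_events only under --no-stream: Gradio's
cancels kills the event outright, which mid-stream would cut the connection to a
generation that is still running, whereas the shared flag lets a streaming run stop
itself cleanly and keep the partial reply already produced.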