diff --git a/modules/chat.py b/modules/chat.py
index fd1812b8..f6b4c323 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -152,6 +152,7 @@ def get_stopping_strings(state):
 def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
     output = copy.deepcopy(history)
     output = apply_extensions('history', output)
+    state = apply_extensions('state', state)
     if shared.model_name == 'None' or shared.model is None:
         logger.error("No model is loaded! Select one in the Model tab.")
         yield output
@@ -161,6 +162,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
     just_started = True
     visible_text = None
     stopping_strings = get_stopping_strings(state)
+    is_stream = state['stream']

     # Preparing the input
     if not any((regenerate, _continue)):
@@ -204,11 +206,11 @@

             # Extract the reply
             visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
-            visible_reply = apply_extensions("output", visible_reply)

             # We need this global variable to handle the Stop event,
             # otherwise gradio gets confused
             if shared.stop_everything:
+                output['visible'][-1][1] = apply_extensions("output", output['visible'][-1][1])
                 yield output
                 return

@@ -221,12 +223,12 @@
             if _continue:
                 output['internal'][-1] = [text, last_reply[0] + reply]
                 output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
-                if state['stream']:
+                if is_stream:
                     yield output
             elif not (j == 0 and visible_reply.strip() == ''):
                 output['internal'][-1] = [text, reply.lstrip(' ')]
                 output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
-                if state['stream']:
+                if is_stream:
                     yield output

         if reply in [None, cumulative_reply]:
@@ -234,6 +236,7 @@
                 break
             else:
                 cumulative_reply = reply

+    output['visible'][-1][1] = apply_extensions("output", output['visible'][-1][1])
     yield output

diff --git a/modules/text_generation.py b/modules/text_generation.py
index 81ada7ee..78d74ed7 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -103,9 +103,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i
             if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith('▁'):
                 reply = ' ' + reply

-    if not is_chat:
-        reply = apply_extensions('output', reply)
-
     return reply


@@ -170,7 +167,6 @@ def apply_stopping_strings(reply, all_stop_strings):


 def _generate_reply(question, state, stopping_strings=None, is_chat=False):
-    state = apply_extensions('state', state)
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:
         if shared.model_name == 'None' or shared.model is None:
@@ -188,6 +184,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
     # Preparing the input
     original_question = question
     if not is_chat:
+        state = apply_extensions('state', state)
         question = apply_extensions('input', question)

     # Finding the stopping strings
@@ -219,6 +216,9 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
             if stop_found:
                 break

+        if not is_chat:
+            reply = apply_extensions('output', reply)
+
         yield reply

@@ -311,15 +311,9 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
         if not state['stream']:
             reply = shared.model.generate(question, state)
-            if not is_chat:
-                reply = apply_extensions('output', reply)
-
             yield reply
         else:
             for reply in shared.model.generate_with_streaming(question, state):
-                if not is_chat:
-                    reply = apply_extensions('output', reply)
-
                 yield reply

     except Exception:
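
Note: in the chat path, the 'state' hook now runs once at the top of chatbot_wrapper, with is_stream read immediately afterwards so its value stays stable through the loop, and the 'output' hook runs once on the final visible reply (or on Stop) instead of on every streamed chunk; in the non-chat path, both hooks now live in _generate_reply. The sketch below illustrates the extension-dispatch pattern these apply_extensions calls rely on. The hook names (state_modifier, output_modifier) follow the project's extension API, but the dispatcher and ExampleExtension here are a minimal illustrative sketch, not the actual modules/extensions.py.

    # Minimal sketch only -- not the project's modules/extensions.py.
    class ExampleExtension:
        """Toy extension: disables streaming and uppercases the final reply."""

        def state_modifier(self, state):
            # Runs once per request now that chatbot_wrapper applies the
            # 'state' hook before reading is_stream.
            state['stream'] = False
            return state

        def output_modifier(self, string):
            # Runs once on the final reply rather than per streamed chunk.
            return string.upper()


    loaded_extensions = [ExampleExtension()]  # stand-in for the loaded-extension registry


    def apply_extensions(event, value):
        # Chain each loaded extension's matching hook over the value, the way
        # the 'state', 'input', and 'output' events are dispatched.
        hook_name = f"{event}_modifier"
        for extension in loaded_extensions:
            hook = getattr(extension, hook_name, None)
            if callable(hook):
                value = hook(value)
        return value


    if __name__ == '__main__':
        state = apply_extensions('state', {'stream': True})
        print(state['stream'])                      # False
        print(apply_extensions('output', 'hello'))  # HELLO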