From 67623a52b7b2d9c6bbd7385d5dfd19c3db2eb3f3 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 25 Feb 2023 00:55:19 -0300 Subject: [PATCH] Allow for permanent hijacking --- extensions/send_pictures/script.py | 11 ++++++----- modules/chat.py | 9 ++++++--- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index 2a7464b4..6dc35e2f 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -9,15 +9,16 @@ from modules.bot_picture import caption_image params = { } -# If 'state' is True, will hijack the next chatbot wrapper call -# with a custom input text +# If 'state' is 'temporary' or 'permanent', will hijack the next +# chatbot wrapper call with a custom input text and optionally +# custom output text input_hijack = { - 'state': False, - 'value': ["", ""] + 'state': 'off', + 'value': [] } prompt_hijack = { - 'state': False, + 'state': 'off', 'value': "" } diff --git a/modules/chat.py b/modules/chat.py index 3d43ad44..03dae071 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -91,8 +91,9 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical visible_text = None prompt = None for extension, _ in extensions_module.iterator(): - if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True: - extension.input_hijack['state'] = False + if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] in ['temporary', 'permanent']: + if extension.input_hijack['state'] == 'temporary': + extension.input_hijack['state'] = 'off' values = extension.input_hijack['value'] if len(values) == 2: text, visible_text = values @@ -102,7 +103,9 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical shared.history['internal'].append([text, reply]) shared.history['visible'].append([visible_text, visible_reply]) return shared.history['visible'] - if hasattr(extension, 'prompt_hijack') and extension.prompt_hijack['state'] == True: + if hasattr(extension, 'prompt_hijack') and extension.prompt_hijack['state'] in ['temporary', 'permanent']: + if extension.prompt_hijack['state'] == 'temporary': + extension.prompt_hijack['state'] = 'off' prompt = extension.prompt_hijack['value'] just_started = True