From 111b5d42e71a859458bca979836c4f58773bd977 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 25 Feb 2023 00:49:18 -0300 Subject: [PATCH] Add prompt hijack option for extensions --- extensions/send_pictures/script.py | 5 +++++ modules/chat.py | 30 ++++++++++++++++++++++-------- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index 88ac1797..2a7464b4 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -16,6 +16,11 @@ input_hijack = { 'value': ["", ""] } +prompt_hijack = { + 'state': False, + 'value': "" +} + def generate_chat_picture(picture, name1, name2): text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*' buffer = BytesIO() diff --git a/modules/chat.py b/modules/chat.py index 140d2ed8..3d43ad44 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -86,25 +86,39 @@ def stop_everything_event(): def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size): shared.stop_everything = False + + # Check if any extension wants to hijack this function call + visible_text = None + prompt = None + for extension, _ in extensions_module.iterator(): + if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True: + extension.input_hijack['state'] = False + values = extension.input_hijack['value'] + if len(values) == 2: + text, visible_text = values + elif len(values) == 4: + text, visible_text, reply, visible_reply = valueso + if not shared.stop_everything: + shared.history['internal'].append([text, reply]) + shared.history['visible'].append([visible_text, visible_reply]) + return shared.history['visible'] + if hasattr(extension, 'prompt_hijack') and extension.prompt_hijack['state'] == True: + prompt = extension.prompt_hijack['value'] + just_started = True eos_token = '\n' if check else None if 'pygmalion' in shared.model_name.lower(): name1 = "You" - # Hijacking the input using an extension - visible_text = None - for extension, _ in extensions_module.iterator(): - if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True: - text, visible_text = extension.input_hijack['value'] - extension.input_hijack['state'] = False - if visible_text is None: visible_text = text if shared.args.chat: visible_text = visible_text.replace('\n', '
') text = apply_extensions(text, "input") - prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) + + if prompt is None: + prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) # Generate for reply in generate_reply(prompt, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):