From 13f26881344ac56a9724cdfd8c1e494c75425818 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 25 Feb 2023 01:08:17 -0300 Subject: [PATCH] Better way to generate custom prompts --- extensions/send_pictures/script.py | 7 +----- modules/chat.py | 36 ++++++++++-------------------- 2 files changed, 13 insertions(+), 30 deletions(-) diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index 6dc35e2f..c6321034 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -14,12 +14,7 @@ params = { # custom output text input_hijack = { 'state': 'off', - 'value': [] -} - -prompt_hijack = { - 'state': 'off', - 'value': "" + 'value': ["", ""] } def generate_chat_picture(picture, name1, name2): diff --git a/modules/chat.py b/modules/chat.py index 03dae071..9ef3bb15 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -86,33 +86,19 @@ def stop_everything_event(): def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size): shared.stop_everything = False + just_started = True + eos_token = '\n' if check else None + if 'pygmalion' in shared.model_name.lower(): + name1 = "You" # Check if any extension wants to hijack this function call visible_text = None - prompt = None + custom_prompt_generator = None for extension, _ in extensions_module.iterator(): - if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] in ['temporary', 'permanent']: - if extension.input_hijack['state'] == 'temporary': - extension.input_hijack['state'] = 'off' - values = extension.input_hijack['value'] - if len(values) == 2: - text, visible_text = values - elif len(values) == 4: - text, visible_text, reply, visible_reply = valueso - if not shared.stop_everything: - shared.history['internal'].append([text, reply]) - shared.history['visible'].append([visible_text, visible_reply]) - return shared.history['visible'] - if hasattr(extension, 'prompt_hijack') and extension.prompt_hijack['state'] in ['temporary', 'permanent']: - if extension.prompt_hijack['state'] == 'temporary': - extension.prompt_hijack['state'] = 'off' - prompt = extension.prompt_hijack['value'] - - just_started = True - eos_token = '\n' if check else None - - if 'pygmalion' in shared.model_name.lower(): - name1 = "You" + if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True: + text, visible_text = extension.input_hijack['value'] + if custom_prompt_generator is None and hasattr(extension, 'custom_prompt_generator'): + custom_prompt_generator = extension.custom_prompt_generator if visible_text is None: visible_text = text @@ -120,8 +106,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical visible_text = visible_text.replace('\n', '
') text = apply_extensions(text, "input") - if prompt is None: + if custom_prompt_generator is None: prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) + else: + prompt = custom_prompt_generator(text, max_new_tokens, name1, name2, context, chat_prompt_size) # Generate for reply in generate_reply(prompt, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):