From 08448fb63776bf7207a69afdbf83313b753310ef Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 3 Apr 2023 01:02:11 -0300 Subject: [PATCH] Rename a variable --- modules/chat.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index ae478f08..73a3814d 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -58,10 +58,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat else: return prompt -def extract_message_from_reply(reply, name1, name2, check): +def extract_message_from_reply(reply, name1, name2, stop_at_newline): next_character_found = False - if check: + if stop_at_newline: lines = reply.split('\n') reply = lines[0].strip() if len(lines) > 1: @@ -85,9 +85,9 @@ def extract_message_from_reply(reply, name1, name2, check): reply = fix_newlines(reply) return reply, next_character_found -def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False): +def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False): just_started = True - eos_token = '\n' if check else None + eos_token = '\n' if stop_at_newline else None name1_original = name1 if 'pygmalion' in shared.model_name.lower(): name1 = "You" @@ -125,7 +125,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical reply = cumulative_reply + reply # Extracting the reply - reply, next_character_found = extract_message_from_reply(reply, name1, name2, check) + reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline) visible_reply = re.sub("(||{{user}})", name1_original, reply) visible_reply = apply_extensions(visible_reply, "output") if shared.args.chat: @@ -152,8 +152,8 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical yield shared.history['visible'] -def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): - eos_token = '\n' if check else None +def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): + eos_token = '\n' if stop_at_newline else None if 'pygmalion' in shared.model_name.lower(): name1 = "You" @@ -168,7 +168,7 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ reply = None for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): reply = cumulative_reply + reply - reply, next_character_found = extract_message_from_reply(reply, name1, name2, check) + reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline) yield reply if next_character_found: break @@ -178,11 +178,11 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ yield reply -def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): - for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): +def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): + for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts): yield generate_chat_html(_history, name1, name2, shared.character) -def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): +def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) else: @@ -190,7 +190,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi last_internal = shared.history['internal'].pop() # Yield '*Is typing...*' yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character) - for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True): + for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] else: