diff --git a/modules/chat.py b/modules/chat.py index 2048e2c5..f727ce8b 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -22,6 +22,12 @@ def clean_chat_message(text): text = text.strip() return text +def generate_chat_output(history, name1, name2, character): + if shared.args.cai_chat: + return generate_chat_html(history, name1, name2, character) + else: + return history + def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False): user_input = clean_chat_message(user_input) rows = [f"{context.strip()}\n"] @@ -182,21 +188,18 @@ def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: - if shared.args.cai_chat: - yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) - else: - yield shared.history['visible'] + yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) else: last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() + yield generate_chat_output(shared.history['visible']+[[last_visible[0], '*Is typing...*']], name1, name2, shared.character) for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] - yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) else: shared.history['visible'][-1] = (last_visible[0], _history[-1][1]) - yield shared.history['visible'] + yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) def remove_last_message(name1, name2): if len(shared.history['visible']) > 0 and not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>': @@ -204,6 +207,7 @@ def remove_last_message(name1, name2): shared.history['internal'].pop() else: last = ['', ''] + if shared.args.cai_chat: return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0] else: @@ -223,10 +227,7 @@ def replace_last_reply(text, name1, name2): shared.history['visible'][-1] = (shared.history['visible'][-1][0], text) shared.history['internal'][-1][1] = apply_extensions(text, "input") - if shared.args.cai_chat: - return generate_chat_html(shared.history['visible'], name1, name2, shared.character) - else: - return shared.history['visible'] + return generate_chat_output(shared.history['visible'], name1, name2, shared.character) def clear_html(): return generate_chat_html([], "", "", shared.character) @@ -246,10 +247,8 @@ def clear_chat_log(name1, name2): else: shared.history['internal'] = [] shared.history['visible'] = [] - if shared.args.cai_chat: - return generate_chat_html(shared.history['visible'], name1, name2, shared.character) - else: - return shared.history['visible'] + + return generate_chat_output(shared.history['visible'], name1, name2, shared.character) def redraw_html(name1, name2): return generate_chat_html(shared.history['visible'], name1, name2, shared.character)