diff --git a/modules/chat.py b/modules/chat.py index 3f313db2..36265990 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -11,18 +11,11 @@ from PIL import Image import modules.extensions as extensions_module import modules.shared as shared from modules.extensions import apply_extensions -from modules.html_generator import generate_chat_html +from modules.html_generator import fix_newlines, generate_chat_html from modules.text_generation import (encode, generate_reply, get_max_prompt_length) -# This gets the new line characters right. -def clean_chat_message(text): - text = text.replace('\n', '\n\n') - text = re.sub(r"\n{3,}", "\n\n", text) - text = text.strip() - return text - def generate_chat_output(history, name1, name2, character): if shared.args.cai_chat: return generate_chat_html(history, name1, name2, character) @@ -30,7 +23,7 @@ def generate_chat_output(history, name1, name2, character): return history def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False): - user_input = clean_chat_message(user_input) + user_input = fix_newlines(user_input) rows = [f"{context.strip()}\n"] if shared.soft_prompt: @@ -83,7 +76,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate if idx != -1: reply = reply[:idx] next_character_found = True - reply = clean_chat_message(reply) + reply = fix_newlines(reply) # If something like "\nYo" is generated just before "\nYou:" # is completed, trim it diff --git a/modules/html_generator.py b/modules/html_generator.py index 9942e6c9..940d5486 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -1,6 +1,6 @@ ''' -This is a library for formatting GPT-4chan and chat outputs as nice HTML. +This is a library for formatting text outputs as nice HTML. ''' @@ -21,10 +21,26 @@ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r') with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f: cai_css = f.read() -def generate_basic_html(s): - s = '\n'.join([f'

{line}

' for line in s.split('\n')]) - s = f'
{s}
' - return s +def fix_newlines(string): + string = string.replace('\n', '\n\n') + string = re.sub(r"\n{3,}", "\n\n", string) + string = string.strip() + return string + +# This could probably be generalized and improved +def convert_to_markdown(string): + string = string.replace('\\begin{code}', '```') + string = string.replace('\\end{code}', '```') + string = string.replace('\\begin{blockquote}', '> ') + string = string.replace('\\end{blockquote}', '') + string = re.sub(r"(.)```", r"\1\n```", string) +# string = fix_newlines(string) + return markdown.markdown(string, extensions=['fenced_code']) + +def generate_basic_html(string): + string = convert_to_markdown(string) + string = f'
{string}
' + return string def process_post(post, c): t = post.split('\n') @@ -108,7 +124,7 @@ def generate_chat_html(history, name1, name2, character): img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"]) for i,_row in enumerate(history[::-1]): - row = [markdown.markdown(re.sub(r"(.)```", r"\1\n```", entry), extensions=['fenced_code']) for entry in _row] + row = [convert_to_markdown(entry) for entry in _row] output += f"""