From a4e903e932c6b3b43b2ccb88f9e75049b2ac4b2e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 16 Aug 2023 09:23:29 -0700 Subject: [PATCH] Escape HTML in chat messages --- modules/chat.py | 25 +++++++++++++------------ modules/html_generator.py | 2 ++ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index d83e9490..d81d254f 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -1,6 +1,7 @@ import base64 import copy import functools +import html import json import re from pathlib import Path @@ -188,15 +189,16 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess yield output return - # Defining some variables just_started = True visible_text = None stopping_strings = get_stopping_strings(state) is_stream = state['stream'] - # Preparing the input + # Prepare the input if not any((regenerate, _continue)): - visible_text = text + visible_text = html.escape(text) + + # Apply extensions text, visible_text = apply_extensions('chat_input', text, visible_text, state) text = apply_extensions('input', text, state, is_chat=True) @@ -208,6 +210,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess if regenerate: output['visible'].pop() output['internal'].pop() + # *Is typing...* if loading_message: yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} @@ -216,12 +219,11 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess if loading_message: yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']} - # Generating the prompt + # Generate the prompt kwargs = { '_continue': _continue, 'history': output, } - prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs) if prompt is None: prompt = generate_chat_prompt(text, state, **kwargs) @@ -232,9 +234,8 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess # Extract the reply visible_reply = re.sub("(||{{user}})", state['name1'], reply) + visible_reply = html.escape(visible_reply) - # We need this global variable to handle the Stop event, - # otherwise gradio gets confused if shared.stop_everything: output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True) yield output @@ -307,8 +308,8 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): def remove_last_message(history): if len(history['visible']) > 0 and history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': - last = history['visible'].pop() - history['internal'].pop() + last = history['internal'].pop() + history['visible'].pop() else: last = ['', ''] @@ -328,7 +329,7 @@ def replace_last_reply(text, state): if len(text.strip()) == 0: return history elif len(history['visible']) > 0: - history['visible'][-1][1] = text + history['visible'][-1][1] = html.escape(text) history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True) return history @@ -336,7 +337,7 @@ def replace_last_reply(text, state): def send_dummy_message(text, state): history = state['history'] - history['visible'].append([text, '']) + history['visible'].append([html.escape(text), '']) history['internal'].append([apply_extensions('input', text, state, is_chat=True), '']) return history @@ -347,7 +348,7 @@ def send_dummy_reply(text, state): history['visible'].append(['', '']) history['internal'].append(['', '']) - history['visible'][-1][1] = text + history['visible'][-1][1] = html.escape(text) history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True) return history diff --git a/modules/html_generator.py b/modules/html_generator.py index eb1da374..3d9f758b 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -1,3 +1,4 @@ +import html import os import re import time @@ -85,6 +86,7 @@ def convert_to_markdown(string): def generate_basic_html(string): + string = html.escape(string) string = convert_to_markdown(string) string = f'
{string}
' return string