Fix HTML escaping for perplexity_colors extension

This commit is contained in:
oobabooga 2023-08-20 21:40:22 -07:00
parent 6394fef1db
commit a74dd9003f
2 changed files with 6 additions and 3 deletions

View File

@ -89,7 +89,6 @@ def convert_to_markdown(string):
def generate_basic_html(string): def generate_basic_html(string):
string = html.escape(string)
string = convert_to_markdown(string) string = convert_to_markdown(string)
string = f'<style>{readable_css}</style><div class="container">{string}</div>' string = f'<style>{readable_css}</style><div class="container">{string}</div>'
return string return string

View File

@ -1,5 +1,6 @@
import ast import ast
import copy import copy
import html
import random import random
import re import re
import time import time
@ -31,7 +32,7 @@ def generate_reply(*args, **kwargs):
shared.generation_lock.release() shared.generation_lock.release()
def _generate_reply(question, state, stopping_strings=None, is_chat=False): def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False):
# Find the appropriate generation function # Find the appropriate generation function
generate_func = apply_extensions('custom_generate_reply') generate_func = apply_extensions('custom_generate_reply')
@ -73,6 +74,9 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
# Generate # Generate
for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat): for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
if escape_html:
reply = html.escape(reply)
reply, stop_found = apply_stopping_strings(reply, all_stop_strings) reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
if is_stream: if is_stream:
cur_time = time.time() cur_time = time.time()
@ -138,7 +142,7 @@ def generate_reply_wrapper(question, state, stopping_strings=None):
reply = question if not shared.is_seq2seq else '' reply = question if not shared.is_seq2seq else ''
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
for reply in generate_reply(question, state, stopping_strings, is_chat=False): for reply in generate_reply(question, state, stopping_strings, is_chat=False, escape_html=True):
if not shared.is_seq2seq: if not shared.is_seq2seq:
reply = question + reply reply = question + reply