diff --git a/modules/chat.py b/modules/chat.py index de7f19de..7f25254c 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -82,7 +82,11 @@ def generate_chat_prompt(user_input, state, **kwargs): history = kwargs.get('history', state['history'])['internal'] # Templates - chat_template = jinja_env.from_string(state['chat_template_str']) + chat_template_str = state['chat_template_str'] + if state['mode'] != 'instruct': + chat_template_str = replace_character_names(chat_template_str, state['name1'], state['name2']) + + chat_template = jinja_env.from_string(chat_template_str) instruction_template = jinja_env.from_string(state['instruction_template_str']) chat_renderer = partial(chat_template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2']) instruct_renderer = partial(instruction_template.render, add_generation_prompt=False)