diff --git a/modules/chat.py b/modules/chat.py index ebb9ef36..35b2eebb 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -95,7 +95,18 @@ def generate_chat_prompt(user_input, state, **kwargs): def get_stopping_strings(state): if state['mode'] == 'instruct': - stopping_strings = [f"\n{state['name1_instruct']}", f"\n{state['name2_instruct']}"] + stopping_strings = [ + state['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0] + '<|bot|>', + state['turn_template'].split('<|bot-message|>')[1] + '<|user|>' + ] + + replacements = { + '<|user|>': state['name1_instruct'], + '<|bot|>': state['name2_instruct'] + } + + for i in range(len(stopping_strings)): + stopping_strings[i] = replace_all(stopping_strings[i], replacements).rstrip(' ').replace(r'\n', '\n') else: stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]