diff --git a/modules/chat.py b/modules/chat.py index 5d2bdd63..920c0f7b 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -3,6 +3,7 @@ import copy import functools import html import json +import pprint import re from datetime import datetime from functools import partial @@ -259,10 +260,27 @@ def get_stopping_strings(state): suffix_bot + prefix_user, ] + # Try to find the EOT token + for item in stopping_strings.copy(): + item = item.strip() + if item.startswith("<") and ">" in item: + stopping_strings.append(item.split(">")[0] + ">") + elif item.startswith("[") and "]" in item: + stopping_strings.append(item.split("]")[0] + "]") + if 'stopping_strings' in state and isinstance(state['stopping_strings'], list): stopping_strings += state.pop('stopping_strings') - return list(set(stopping_strings)) + # Remove redundant items that start with another item + result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)] + result = list(set(result)) + + if shared.args.verbose: + logger.info("STOPPING_STRINGS=") + pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result) + print() + + return result def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):