Obtain the EOT token from the jinja template (attempt)

To use as a stopping string.
This commit is contained in:
oobabooga 2024-06-30 15:09:22 -07:00
parent 3e3f8637d6
commit ed01322763

View File

@ -3,6 +3,7 @@ import copy
import functools import functools
import html import html
import json import json
import pprint
import re import re
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
@ -259,10 +260,27 @@ def get_stopping_strings(state):
suffix_bot + prefix_user, suffix_bot + prefix_user,
] ]
# Try to find the EOT token
for item in stopping_strings.copy():
item = item.strip()
if item.startswith("<") and ">" in item:
stopping_strings.append(item.split(">")[0] + ">")
elif item.startswith("[") and "]" in item:
stopping_strings.append(item.split("]")[0] + "]")
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list): if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
stopping_strings += state.pop('stopping_strings') stopping_strings += state.pop('stopping_strings')
return list(set(stopping_strings)) # Remove redundant items that start with another item
result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)]
result = list(set(result))
if shared.args.verbose:
logger.info("STOPPING_STRINGS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result)
print()
return result
def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False): def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):