From ed013227630bb2a2e76cc4d17f43261c31e8bba2 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 30 Jun 2024 15:09:22 -0700 Subject: [PATCH] Obtain the EOT token from the jinja template (attempt) To use as a stopping string. --- modules/chat.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/modules/chat.py b/modules/chat.py index 5d2bdd63..920c0f7b 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -3,6 +3,7 @@ import copy import functools import html import json +import pprint import re from datetime import datetime from functools import partial @@ -259,10 +260,27 @@ def get_stopping_strings(state): suffix_bot + prefix_user, ] + # Try to find the EOT token + for item in stopping_strings.copy(): + item = item.strip() + if item.startswith("<") and ">" in item: + stopping_strings.append(item.split(">")[0] + ">") + elif item.startswith("[") and "]" in item: + stopping_strings.append(item.split("]")[0] + "]") + if 'stopping_strings' in state and isinstance(state['stopping_strings'], list): stopping_strings += state.pop('stopping_strings') - return list(set(stopping_strings)) + # Remove redundant items that start with another item + result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)] + result = list(set(result)) + + if shared.args.verbose: + logger.info("STOPPING_STRINGS=") + pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result) + print() + + return result def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):