From 9d3c6d2dc32474b2a6ea3c2a1d4e50b2fef69026 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 16 Apr 2023 01:40:47 -0300 Subject: [PATCH] Fix a bug --- modules/text_generation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index 65a1da94..51e2ddc0 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -193,7 +193,8 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): # Handling the stopping strings stopping_criteria_list = transformers.StoppingCriteriaList() - for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")]): + print(ast.literal_eval(f"[{state['custom_stopping_strings']}]")) + for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")): if type(st) is list and len(st) > 0: sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st] stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))