diff --git a/modules/text_generation.py b/modules/text_generation.py
index 65a1da94..51e2ddc0 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -193,7 +193,8 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
 
     # Handling the stopping strings
     stopping_criteria_list = transformers.StoppingCriteriaList()
-    for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")]):
+    print(ast.literal_eval(f"[{state['custom_stopping_strings']}]"))
+    for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
         if type(st) is list and len(st) > 0:
             sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
             stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
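
For context, the f"[{...}]" / ast.literal_eval idiom in the hunk above turns the raw custom_stopping_strings field into a Python list. A minimal sketch of that parsing step, assuming the field holds a comma-separated set of quoted strings (the example value below is hypothetical, not taken from the patch):

import ast

# Hypothetical example of what a user might type into the custom stopping
# strings field: a comma-separated list of quoted strings.
custom_stopping_strings = '"\\nYou:", "\\nBot:"'

# Wrapping the raw field in brackets makes it a Python list literal, which
# ast.literal_eval parses safely (literals only, no arbitrary code execution).
parsed = ast.literal_eval(f"[{custom_stopping_strings}]")
print(parsed)  # -> ['\nYou:', '\nBot:']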