From 3b8f0021cc3b2134b08a0e5178d87828b0b84b9c Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 25 Jan 2023 10:17:55 -0300
Subject: [PATCH] Stop generating at \nYou: in chat mode

---
 README.md |  2 +-
 server.py | 27 +++++++++++++++++++--------
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 76dbaa09..ca134a0a 100644
--- a/README.md
+++ b/README.md
@@ -153,6 +153,6 @@ Pull requests, suggestions, and issue reports are welcome.
 ## Credits
 
 - NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
-- Pygmalion preset: https://github.com/PygmalionAI/gradio-ui/blob/master/src/gradio_ui.py
+- Pygmalion preset, code for early stopping in chat mode: https://github.com/PygmalionAI/gradio-ui/
 - Verbose preset: Anonymous 4chan user.
 - Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui
diff --git a/server.py b/server.py
index 79556f4a..0e6983e7 100644
--- a/server.py
+++ b/server.py
@@ -14,6 +14,7 @@ import transformers
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from modules.html_generator import *
 from modules.ui import *
+from modules.stopping_criteria import _SentinelTokenStoppingCriteria
 
 transformers.logging.set_verbosity_error()
 
@@ -135,12 +136,12 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
-def encode(prompt, tokens):
+def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
     if args.cpu:
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
+        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
     else:
         torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
+        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
     return input_ids
 
 def decode(output_ids):
@@ -161,7 +162,7 @@ def formatted_outputs(reply, model_name):
     else:
         return reply
 
-def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
+def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
     if selected_model != model_name:
@@ -179,11 +180,22 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     cuda = "" if args.cpu else ".cuda()"
     n = None if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
     input_ids = encode(question, tokens)
+    # The stopping_criteria code below was copied from
+    # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
+    if stopping_string is not None:
+        t = encode(stopping_string, 0, add_special_tokens=False)
+        stopping_criteria_list = transformers.StoppingCriteriaList([
+            _SentinelTokenStoppingCriteria(
+                sentinel_token_ids=t,
+                starting_idx=len(input_ids[0]))
+        ])
+    else:
+        stopping_criteria_list = None
 
     # Generate the entire reply at once
     if args.no_stream:
         t0 = time.time()
-        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
+        output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
         reply = decode(output[0])
         t1 = time.time()
         print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
@@ -194,11 +206,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         yield formatted_outputs(question, model_name)
         preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
         for i in tqdm(range(tokens)):
-            output = eval(f"model.generate(input_ids, {preset}){cuda}")
+            output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
             reply = decode(output[0])
             if eos_token is not None and reply[-1] == eos_token:
                 break
-
             yield formatted_outputs(reply, model_name)
             input_ids = output
 
@@ -289,7 +300,7 @@ if args.chat or args.cai_chat:
         question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
         history.append(['', ''])
         eos_token = '\n' if check else None
-        for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
+        for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
             next_character_found = False
             previous_idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", question)]
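
Note: the patch imports _SentinelTokenStoppingCriteria from modules/stopping_criteria.py, but that file is not included in this diff. The following is a minimal sketch of what such a criteria class could look like, adapted from the PygmalionAI gradio-ui model.py referenced in the comment above; the exact contents of modules/stopping_criteria.py in this repository may differ, so treat the class body below as an assumption rather than the committed implementation.

# Sketch (assumption): a possible modules/stopping_criteria.py, adapted from
# the PygmalionAI gradio-ui code referenced in the patch comment.
import torch
import transformers


class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    def __init__(self, sentinel_token_ids: torch.LongTensor, starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        # Token ids of the stopping string (e.g. "\nYou:"), shape (1, n)
        self.sentinel_token_ids = sentinel_token_ids
        # Length of the prompt; only tokens generated after it are inspected
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            # Not enough newly generated tokens yet to contain the sentinel sequence
            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
                continue
            # Slide a window of sentinel length over the generated tokens
            for window in trimmed_sample.unfold(0, self.sentinel_token_ids.shape[-1], 1):
                if torch.all(torch.eq(self.sentinel_token_ids, window)):
                    return True
        return False

With a class along these lines, the stopping_criteria=stopping_criteria_list argument passed to model.generate() in the hunks above halts generation as soon as the encoded f"\n{name1}:" sequence appears among the newly generated tokens, so the bot stops before writing the user's next line in chat mode.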