From de6a09dc7f7d5a5d8496cfa1598abb4ff5ee1338 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 00:12:40 -0300
Subject: [PATCH] Properly separate the original prompt from the reply

---
 modules/text_generation.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/modules/text_generation.py b/modules/text_generation.py
index 610bd4fc..d539f6d4 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -136,6 +136,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     input_ids = encode(question, max_new_tokens)
     original_input_ids = input_ids
     output = input_ids[0]
+    cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
 
     eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
     if eos_token is not None:
@@ -146,9 +147,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         t = encode(stopping_string, 0, add_special_tokens=False)
         stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
 
-    generate_params = {
-        'use_cache': not shared.args.no_cache,
-    }
+    generate_params = {}
     if not shared.args.flexgen:
         generate_params.update({
             "max_new_tokens": max_new_tokens,
@@ -175,6 +174,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             "temperature": temperature,
             "stop": eos_token_ids[-1],
         })
+    if shared.args.no_cache:
+        generate_params.update({"use_cache": False})
     if shared.args.deepspeed:
         generate_params.update({"synced_gpus": True})
     if shared.soft_prompt:
@@ -194,9 +195,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             if shared.soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
-            reply = decode(output)
             if not (shared.args.chat or shared.args.cai_chat):
-                reply = original_question + apply_extensions(reply[len(question):], "output")
+                new_tokens = len(output) - len(input_ids[0])
+                reply = decode(output[-new_tokens:])
+                reply = original_question + apply_extensions(reply, "output")
+            else:
+                reply = decode(output)
 
             yield formatted_outputs(reply, shared.model_name)
 
@@ -219,10 +223,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 for output in generator:
                     if shared.soft_prompt:
                         output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-                    reply = decode(output)
-
                     if not (shared.args.chat or shared.args.cai_chat):
-                        reply = original_question + apply_extensions(reply[len(question):], "output")
+                        new_tokens = len(output) - len(input_ids[0])
+                        reply = decode(output[-new_tokens:])
+                        reply = original_question + apply_extensions(reply, "output")
+                    else:
+                        reply = decode(output)
 
                     if output[-1] in eos_token_ids:
                         break
@@ -238,10 +244,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                     output = shared.model.generate(**generate_params)[0]
                 if shared.soft_prompt:
                     output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-                reply = decode(output)
-
                 if not (shared.args.chat or shared.args.cai_chat):
-                    reply = original_question + apply_extensions(reply[len(question):], "output")
+                    new_tokens = len(output) - len(original_input_ids[0])
+                    reply = decode(output[-new_tokens:])
+                    reply = original_question + apply_extensions(reply, "output")
+                else:
+                    reply = decode(output)
 
                 if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
                     break
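
Note (reviewer sketch, not part of the patch): the old code decoded the whole output sequence and then stripped the prompt by character count with reply[len(question):], which goes wrong whenever decoding the prompt tokens does not reproduce the original prompt string exactly (special tokens, whitespace or unicode normalization). The patch instead counts the newly generated tokens and decodes only that tail, and only decodes the full sequence in chat mode. A minimal standalone illustration of the same slicing idea, assuming the Hugging Face transformers GPT-2 tokenizer as a stand-in for the repo's encode()/decode() helpers:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    prompt = "The quick brown fox"
    prompt_ids = tokenizer.encode(prompt)

    # Pretend the model appended these tokens during generation.
    generated = prompt_ids + tokenizer.encode(" jumps over the lazy dog")

    # Old approach: decode everything, then cut by character count.
    # Breaks when decode(prompt_ids) != prompt (special tokens, whitespace
    # normalization, etc.), because the slice offset no longer lines up.
    reply_by_chars = tokenizer.decode(generated)[len(prompt):]

    # New approach (what the patch does): count the new tokens and decode
    # only that tail, so the prompt never has to round-trip through decode().
    new_tokens = len(generated) - len(prompt_ids)
    reply_by_tokens = tokenizer.decode(generated[-new_tokens:])

    print(repr(reply_by_chars))   # ' jumps over the lazy dog'
    print(repr(reply_by_tokens))  # ' jumps over the lazy dog'

A secondary change in the same patch: use_cache is no longer set unconditionally in generate_params; it is only added as use_cache=False when --no-cache is passed, leaving the library default otherwise.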