Stop generating at \nYou: in chat mode

This commit is contained in:
oobabooga 2023-01-25 10:17:55 -03:00
parent 54e77acac4
commit 3b8f0021cc
2 changed files with 20 additions and 9 deletions

View File

@ -153,6 +153,6 @@ Pull requests, suggestions, and issue reports are welcome.
## Credits
- NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
- Pygmalion preset: https://github.com/PygmalionAI/gradio-ui/blob/master/src/gradio_ui.py
- Pygmalion preset, code for early stopping in chat mode: https://github.com/PygmalionAI/gradio-ui/
- Verbose preset: Anonymous 4chan user.
- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui

View File

@ -14,6 +14,7 @@ import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from modules.html_generator import *
from modules.ui import *
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
transformers.logging.set_verbosity_error()
@ -135,12 +136,12 @@ def fix_galactica(s):
s = s.replace(r'$$', r'$')
return s
def encode(prompt, tokens):
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
if args.cpu:
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
else:
torch.cuda.empty_cache()
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
return input_ids
def decode(output_ids):
@ -161,7 +162,7 @@ def formatted_outputs(reply, model_name):
else:
return reply
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
global model, tokenizer, model_name, loaded_preset, preset
if selected_model != model_name:
@ -179,11 +180,22 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
cuda = "" if args.cpu else ".cuda()"
n = None if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
input_ids = encode(question, tokens)
# The stopping_criteria code below was copied from
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
if stopping_string is not None:
t = encode(stopping_string, 0, add_special_tokens=False)
stopping_criteria_list = transformers.StoppingCriteriaList([
_SentinelTokenStoppingCriteria(
sentinel_token_ids=t,
starting_idx=len(input_ids[0]))
])
else:
stopping_criteria_list = None
# Generate the entire reply at once
if args.no_stream:
t0 = time.time()
output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
reply = decode(output[0])
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
@ -194,11 +206,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
yield formatted_outputs(question, model_name)
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
for i in tqdm(range(tokens)):
output = eval(f"model.generate(input_ids, {preset}){cuda}")
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
reply = decode(output[0])
if eos_token is not None and reply[-1] == eos_token:
break
yield formatted_outputs(reply, model_name)
input_ids = output
@ -289,7 +300,7 @@ if args.chat or args.cai_chat:
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
history.append(['', ''])
eos_token = '\n' if check else None
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
next_character_found = False
previous_idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", question)]