Ensure that the chat prompt will always contain < 2048 tokens

This way, we can keep the context string at the top of the prompt
even if you keep talking to the bot for hours.

Before this commit, the prompt would simply be truncated, and the
context string would eventually be lost.
This commit is contained in:
oobabooga 2023-01-17 20:16:23 -03:00
parent 6456777b09
commit ca13acdfa0

View File

@ -116,6 +116,14 @@ def fix_galactica(s):
s = s.replace(r'$$', r'$') s = s.replace(r'$$', r'$')
return s return s
def encode(prompt, tokens):
    """Tokenize *prompt*, leaving room for *tokens* tokens within 2048.

    The prompt is converted with ``str()`` and encoded by the module-level
    ``tokenizer`` with truncation enabled and ``max_length=2048-tokens``.
    Unless ``args.cpu`` is set, the CUDA cache is emptied first and the
    encoded ids are moved to the GPU with ``.cuda()``.

    Returns whatever ``tokenizer.encode(..., return_tensors='pt')`` yields
    (moved to CUDA when not running on CPU).
    """
    use_gpu = not args.cpu

    # Free cached GPU memory before encoding, same as the GPU branch did.
    if use_gpu:
        torch.cuda.empty_cache()

    max_prompt_length = 2048-tokens
    encoded = tokenizer.encode(
        str(prompt),
        return_tensors='pt',
        truncation=True,
        max_length=max_prompt_length,
    )
    return encoded.cuda() if use_gpu else encoded
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None): def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
global model, tokenizer, model_name, loaded_preset, preset global model, tokenizer, model_name, loaded_preset, preset
@ -131,14 +139,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
preset = infile.read() preset = infile.read()
loaded_preset = inference_settings loaded_preset = inference_settings
if not args.cpu: input_ids = encode(question, tokens)
torch.cuda.empty_cache()
input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
cuda = ".cuda()"
else:
input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens)
cuda = ""
cuda = ".cuda()" if args.cpu else ""
if eos_token is None: if eos_token is None:
output = eval(f"model.generate(input_ids, {preset}){cuda}") output = eval(f"model.generate(input_ids, {preset}){cuda}")
else: else:
@ -217,16 +220,20 @@ elif args.chat or args.cai_chat:
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check): def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
text = chat_response_cleaner(text) text = chat_response_cleaner(text)
question = f"{context}\n\n" rows = [f"{context}\n\n"]
for i in range(len(history)): i = len(history)-1
if args.cai_chat: while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
question += f"{name1}: {history[i][0].strip()}\n" rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
question += f"{name2}: {history[i][1].strip()}\n" rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
else: i -= 1
question += f"{name1}: {history[i][0][3:-5].strip()}\n" rows.append(f"{name1}: {text}\n")
question += f"{name2}: {history[i][1][3:-5].strip()}\n" rows.append(f"{name2}:")
question += f"{name1}: {text}\n"
question += f"{name2}:" while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
rows.pop(1)
rows.pop(1)
question = ''.join(rows)
if check: if check:
reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0] reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]