Truncate prompts to 2048 characters

commit 54bf55372b
parent 99d24bdbfe
Author: oobabooga
Date:   2023-01-16 13:43:23 -03:00


@@ -96,6 +96,7 @@ def load_model(model_name):
         tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
+    tokenizer.truncation_side = 'left'
 
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
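
Setting `truncation_side = 'left'` makes the tokenizer drop tokens from the beginning of an over-long prompt rather than the end, so the most recent chat history survives. A minimal sketch of the effect, using the "gpt2" tokenizer as a stand-in for whatever load_model() actually loads:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model, not from this repo
tokenizer.truncation_side = 'left'                 # drop the oldest tokens first

prompt = "old text. " * 1000 + "newest line"       # far longer than the limit
ids = tokenizer.encode(prompt, truncation=True, max_length=2048)

print(len(ids))                       # 2048
print(tokenizer.decode(ids)[-11:])    # 'newest line' -- the head was dropped, not the tail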
@@ -134,10 +135,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     if not args.cpu:
         torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(question), return_tensors='pt').cuda()
+        input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
         cuda = ".cuda()"
     else:
-        input_ids = tokenizer.encode(str(question), return_tensors='pt')
+        input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens)
         cuda = ""
 
     if eos_token is None:
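
`max_length=2048-tokens` budgets the model's fixed 2048-token context window: if the user asked for `tokens` new tokens, at most `2048 - tokens` of the window may come from the prompt, so prompt plus generation never overflows it. A sketch of the arithmetic with hypothetical values, again using "gpt2" as a stand-in tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model
tokenizer.truncation_side = 'left'

tokens = 200                               # generation length the user requested
question = "You: hello\nBot: hi\n" * 500   # an over-long chat transcript

# Keep at most 2048 - 200 = 1848 prompt tokens; the remainder of the
# window is reserved for the reply.
input_ids = tokenizer.encode(question, return_tensors='pt',
                             truncation=True, max_length=2048 - tokens)
assert input_ids.shape[1] + tokens <= 2048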
@@ -231,10 +232,12 @@ elif args.chat or args.cai_chat:
             if check:
                 reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
-                reply = reply[len(question):].split('\n')[0].strip()
+                idx = reply.rfind(question[-500:])
+                reply = reply[idx+min(500, len(question)):].split('\n')[0].strip()
             else:
                 reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
-                reply = reply[len(question):]
+                idx = reply.rfind(question[-500:])
+                reply = reply[idx+min(500, len(question)):]
             idx = reply.find(f"\n{name1}:")
             if idx != -1:
                 reply = reply[:idx]
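
The old extraction, `reply[len(question):]`, assumed the output echoed the full prompt from position 0. Once the prompt can be left-truncated, the echoed text may be shorter than `question`, so that offset over-shoots. The new code instead searches for the last 500 characters of the question inside the output (`rfind`, so the latest occurrence wins) and slices just past them; `min(500, len(question))` handles questions shorter than 500 characters. A sketch with hypothetical strings:

question = "A" * 600 + "You: hi\nBot:"            # pre-truncation prompt
reply = question[100:] + " hello there\nYou:"     # model output lost the prompt's head

idx = reply.rfind(question[-500:])                # locate the question's tail
reply = reply[idx + min(500, len(question)):]     # keep only the newly generated text
print(repr(reply))                                # ' hello there\nYou:'

The `reply.find(f"\n{name1}:")` check that follows then trims the trailing "\nYou:" turn marker.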