Simplify encode() function

oobabooga 2023-02-02 13:31:32 -03:00
parent afc2b0f4c8
commit 3f05cf5ddd


@@ -168,16 +168,13 @@ def fix_galactica(s):
     return s
 
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
-    if args.cpu:
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
-    else:
-        torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
-        if not args.deepspeed:
-            return input_ids
-        else:
-            return input_ids.to(device=local_rank)
+    input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
+    if args.cpu:
+        return input_ids
+    elif args.deepspeed:
+        return input_ids.to(device=local_rank)
+    else:
+        return input_ids.cuda()
 
 def decode(output_ids):
     reply = tokenizer.decode(output_ids, skip_special_tokens=True)
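For reference, a minimal usage sketch of the simplified helpers. It assumes the module-level tokenizer, model, and args globals that the script sets up elsewhere; the generate() call and its arguments are illustrative, not the script's exact invocation:

    input_ids = encode("Hello, my name is", tokens_to_generate=50)  # tensor on CPU, CUDA, or the DeepSpeed local rank
    output_ids = model.generate(input_ids, max_new_tokens=50)       # hypothetical Hugging Face generation call
    print(decode(output_ids[0]))                                    # decode back to text, skipping special tokens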