def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
    """Tokenize ``prompt`` and return the token ids on the configured device.

    Args:
        prompt: Value to tokenize; it is converted with ``str()`` first.
        tokens_to_generate: Subtracted from the fixed 2048 limit to obtain
            the ``max_length`` used for truncation.
        add_special_tokens: Passed through unchanged to ``tokenizer.encode``.

    Returns:
        The tensor produced by ``tokenizer.encode(..., return_tensors='pt')``:
        returned as-is when ``args.cpu`` is set, moved with
        ``.to(device=local_rank)`` when ``args.deepspeed`` is set, and moved
        with ``.cuda()`` otherwise.
    """
    text = str(prompt)
    token_budget = 2048 - tokens_to_generate
    encoded = tokenizer.encode(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=token_budget,
        add_special_tokens=add_special_tokens,
    )

    # The CPU flag is checked first, so it takes precedence over DeepSpeed.
    if args.cpu:
        return encoded

    if args.deepspeed:
        return encoded.to(device=local_rank)

    return encoded.cuda()