Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-12-25 05:48:55 +01:00)
Simplify encode() function
commit 3f05cf5ddd
parent afc2b0f4c8

server.py: 11 lines changed
--- a/server.py
+++ b/server.py
@@ -168,16 +168,13 @@ def fix_galactica(s):
     return s
 
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
+    input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
     if args.cpu:
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
-    else:
-        torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
-
-    if not args.deepspeed:
         return input_ids
-    else:
+    elif args.deepspeed:
         return input_ids.to(device=local_rank)
+    else:
+        return input_ids.cuda()
 
 def decode(output_ids):
     reply = tokenizer.decode(output_ids, skip_special_tokens=True)
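For reference, below is a minimal, self-contained sketch of the control flow after this commit: all three backends now share a single tokenizer.encode() call, only the final device placement differs, and the old GPU branch's torch.cuda.empty_cache() call is gone. The names args, tokenizer, and local_rank match server.py; FakeTensor, FakeTokenizer, and the demo loop are hypothetical stand-ins so the dispatch logic can be exercised without the web UI, a GPU, or DeepSpeed, and the return_tensors/truncation arguments are omitted because the stub tokenizer returns a plain list wrapper.

from argparse import Namespace

# Hypothetical stand-ins (not from the repo) so the dispatch logic can run
# without the web UI, a GPU, or DeepSpeed.
class FakeTensor:
    def __init__(self, ids, device='cpu'):
        self.ids, self.device = ids, device
    def cuda(self):
        return FakeTensor(self.ids, device='cuda:0')
    def to(self, device=None):
        return FakeTensor(self.ids, device=f'deepspeed-rank-{device}')
    def __repr__(self):
        return f'FakeTensor(n={len(self.ids)}, device={self.device!r})'

class FakeTokenizer:
    def encode(self, text, max_length=2048, add_special_tokens=True, **kwargs):
        # Crude stand-in for a real tokenizer: one "token" per character.
        ids = [ord(c) for c in text][:max_length]
        return FakeTensor(ids)

tokenizer = FakeTokenizer()
args = Namespace(cpu=False, deepspeed=False)  # mirrors the CLI flags used in server.py
local_rank = 0

def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
    # One shared tokenizer call; only the device placement differs per backend.
    input_ids = tokenizer.encode(str(prompt), max_length=2048 - tokens_to_generate,
                                 add_special_tokens=add_special_tokens)
    if args.cpu:
        return input_ids
    elif args.deepspeed:
        return input_ids.to(device=local_rank)
    else:
        return input_ids.cuda()

if __name__ == '__main__':
    # Exercise all three branches of the simplified dispatch.
    for cpu, deepspeed in [(True, False), (False, True), (False, False)]:
        args.cpu, args.deepspeed = cpu, deepspeed
        print(f'cpu={cpu} deepspeed={deepspeed} ->',
              encode('hello world', tokens_to_generate=8))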