mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Truncate prompts to 2048 characters
This commit is contained in:
parent
99d24bdbfe
commit
54bf55372b
11
server.py
11
server.py
@ -96,6 +96,7 @@ def load_model(model_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
|
||||
tokenizer.truncation_side = 'left'
|
||||
|
||||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||
return model, tokenizer
|
||||
@ -134,10 +135,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||
|
||||
if not args.cpu:
|
||||
torch.cuda.empty_cache()
|
||||
input_ids = tokenizer.encode(str(question), return_tensors='pt').cuda()
|
||||
input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
|
||||
cuda = ".cuda()"
|
||||
else:
|
||||
input_ids = tokenizer.encode(str(question), return_tensors='pt')
|
||||
input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens)
|
||||
cuda = ""
|
||||
|
||||
if eos_token is None:
|
||||
@ -231,10 +232,12 @@ elif args.chat or args.cai_chat:
|
||||
|
||||
if check:
|
||||
reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
|
||||
reply = reply[len(question):].split('\n')[0].strip()
|
||||
idx = reply.rfind(question[-500:])
|
||||
reply = reply[idx+min(500, len(question)):].split('\n')[0].strip()
|
||||
else:
|
||||
reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
|
||||
reply = reply[len(question):]
|
||||
idx = reply.rfind(question[-500:])
|
||||
reply = reply[idx+min(500, len(question)):]
|
||||
idx = reply.find(f"\n{name1}:")
|
||||
if idx != -1:
|
||||
reply = reply[:idx]
|
||||
|
Loading…
Reference in New Issue
Block a user