mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-24 13:28:59 +01:00
Consider the softprompt in the maximum prompt length calculation
This commit is contained in:
parent
8b3bb512ef
commit
d910d435cd
14
server.py
14
server.py
@ -247,8 +247,15 @@ def fix_galactica(s):
|
||||
s = s.replace(r'$$', r'$')
|
||||
return s
|
||||
|
||||
def get_max_prompt_length(tokens):
|
||||
global soft_prompt, soft_prompt_tensor
|
||||
max_length = 2048-tokens
|
||||
if soft_prompt:
|
||||
max_length -= soft_prompt_tensor.shape[1]
|
||||
return max_length
|
||||
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
|
||||
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
||||
if args.cpu:
|
||||
return input_ids
|
||||
elif args.deepspeed:
|
||||
@ -497,7 +504,8 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe
|
||||
rows = [f"{context.strip()}\n"]
|
||||
i = len(history['internal'])-1
|
||||
count = 0
|
||||
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
|
||||
max_length = get_max_prompt_length(tokens)
|
||||
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
|
||||
rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
|
||||
count += 1
|
||||
if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
|
||||
@ -515,7 +523,7 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe
|
||||
rows.append(f"{name1}:")
|
||||
limit = 2
|
||||
|
||||
while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
|
||||
while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= max_length:
|
||||
rows.pop(1)
|
||||
rows.pop(1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user