mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
The soft prompt length must be considered here too
This commit is contained in:
parent
a6ddbbfc77
commit
596732a981
@ -505,11 +505,17 @@ def clean_chat_message(text):
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
|
def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
|
||||||
|
global soft_prompt, soft_prompt_tensor
|
||||||
|
|
||||||
text = clean_chat_message(text)
|
text = clean_chat_message(text)
|
||||||
rows = [f"{context.strip()}\n"]
|
rows = [f"{context.strip()}\n"]
|
||||||
i = len(history['internal'])-1
|
i = len(history['internal'])-1
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
|
if soft_prompt:
|
||||||
|
chat_prompt_size -= soft_prompt_tensor.shape[1]
|
||||||
max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
|
max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
|
||||||
|
|
||||||
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
|
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
|
||||||
rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
|
rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
|
||||||
count += 1
|
count += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user