mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 10:59:32 +01:00
Stop generating in chat mode when \nYou: is generated
This commit is contained in:
parent
022960a087
commit
df2e910421
14
server.py
14
server.py
@ -143,7 +143,6 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
|||||||
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
|
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
|
||||||
cuda = ".cuda()" if args.cpu else ""
|
cuda = ".cuda()" if args.cpu else ""
|
||||||
for i in range(tokens):
|
for i in range(tokens):
|
||||||
|
|
||||||
if eos_token is None:
|
if eos_token is None:
|
||||||
output = eval(f"model.generate(input_ids, {preset}){cuda}")
|
output = eval(f"model.generate(input_ids, {preset}){cuda}")
|
||||||
else:
|
else:
|
||||||
@ -246,6 +245,7 @@ elif args.chat or args.cai_chat:
|
|||||||
eos_token = '\n' if check else None
|
eos_token = '\n' if check else None
|
||||||
for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
|
for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
|
||||||
reply = i[0]
|
reply = i[0]
|
||||||
|
next_character_found = False
|
||||||
|
|
||||||
if check:
|
if check:
|
||||||
idx = reply.rfind(question[-1024:])
|
idx = reply.rfind(question[-1024:])
|
||||||
@ -256,6 +256,7 @@ elif args.chat or args.cai_chat:
|
|||||||
idx = reply.find(f"\n{name1}:")
|
idx = reply.find(f"\n{name1}:")
|
||||||
if idx != -1:
|
if idx != -1:
|
||||||
reply = reply[:idx]
|
reply = reply[:idx]
|
||||||
|
next_character_found = True
|
||||||
reply = clean_chat_message(reply)
|
reply = clean_chat_message(reply)
|
||||||
|
|
||||||
history[-1] = [text, reply]
|
history[-1] = [text, reply]
|
||||||
@ -263,14 +264,17 @@ elif args.chat or args.cai_chat:
|
|||||||
# Prevent the chat log from flashing if something like "\nYo" is generated just
|
# Prevent the chat log from flashing if something like "\nYo" is generated just
|
||||||
# before "\nYou:" is completed
|
# before "\nYou:" is completed
|
||||||
tmp = f"\n{name1}:"
|
tmp = f"\n{name1}:"
|
||||||
found = False
|
next_character_substring_found = False
|
||||||
for j in range(1, len(tmp)):
|
for j in range(1, len(tmp)+1):
|
||||||
if reply[-j:] == tmp[:j]:
|
if reply[-j:] == tmp[:j]:
|
||||||
found = True
|
next_character_substring_found = True
|
||||||
|
|
||||||
if not found:
|
if not next_character_substring_found:
|
||||||
yield history
|
yield history
|
||||||
|
|
||||||
|
if next_character_found:
|
||||||
|
break
|
||||||
|
|
||||||
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
|
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
|
||||||
for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
|
for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
|
||||||
yield generate_chat_html(history, name1, name2)
|
yield generate_chat_html(history, name1, name2)
|
||||||
|
Loading…
Reference in New Issue
Block a user