Various "impersonate" fixes

This commit is contained in:
oobabooga 2023-05-21 22:54:28 -03:00
parent e116d31180
commit dcc3e54005

View File

@ -53,7 +53,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
history = kwargs.get('history', shared.history)['internal'] history = kwargs.get('history', shared.history)['internal']
is_instruct = state['mode'] == 'instruct' is_instruct = state['mode'] == 'instruct'
# FInd the maximum prompt size # Find the maximum prompt size
chat_prompt_size = state['chat_prompt_size'] chat_prompt_size = state['chat_prompt_size']
if shared.soft_prompt: if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1] chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
@ -295,16 +295,19 @@ def impersonate_wrapper(text, state):
for reply in generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True): for reply in generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True):
reply = cumulative_reply + reply reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, state) reply, next_character_found = extract_message_from_reply(reply, state)
yield reply yield reply.lstrip(' ')
if shared.stop_everything:
return
if next_character_found: if next_character_found:
break break
if reply in [None, '']: if reply in [None, cumulative_reply]:
break break
else: else:
cumulative_reply = reply cumulative_reply = reply
yield cumulative_reply yield cumulative_reply.lstrip(' ')
def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True): def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True):