Bug fix: when generation fails, save the sent message (#4915)

This commit is contained in:
oobabooga 2023-12-15 01:01:45 -03:00 committed by GitHub
parent 11f082e417
commit 2cb5b68ad9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -215,40 +215,47 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
yield output yield output
return return
just_started = True
visible_text = None visible_text = None
stopping_strings = get_stopping_strings(state) stopping_strings = get_stopping_strings(state)
is_stream = state['stream'] is_stream = state['stream']
# Prepare the input # Prepare the input
if not any((regenerate, _continue)): if not (regenerate or _continue):
visible_text = html.escape(text) visible_text = html.escape(text)
# Apply extensions # Apply extensions
text, visible_text = apply_extensions('chat_input', text, visible_text, state) text, visible_text = apply_extensions('chat_input', text, visible_text, state)
text = apply_extensions('input', text, state, is_chat=True) text = apply_extensions('input', text, state, is_chat=True)
output['internal'].append([text, ''])
output['visible'].append([visible_text, ''])
# *Is typing...* # *Is typing...*
if loading_message: if loading_message:
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} yield {
'visible': output['visible'][:-1] + [[output['visible'][-1][0], shared.processing_message]],
'internal': output['internal']
}
else: else:
text, visible_text = output['internal'][-1][0], output['visible'][-1][0] text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
if regenerate: if regenerate:
output['visible'].pop()
output['internal'].pop()
# *Is typing...*
if loading_message: if loading_message:
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} yield {
'visible': output['visible'][:-1] + [[visible_text, shared.processing_message]],
'internal': output['internal'][:-1] + [[text, '']]
}
elif _continue: elif _continue:
last_reply = [output['internal'][-1][1], output['visible'][-1][1]] last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
if loading_message: if loading_message:
yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']} yield {
'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']],
'internal': output['internal']
}
# Generate the prompt # Generate the prompt
kwargs = { kwargs = {
'_continue': _continue, '_continue': _continue,
'history': output, 'history': output if _continue else {k: v[:-1] for k, v in output.items()}
} }
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs) prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
if prompt is None: if prompt is None:
@ -270,12 +277,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
yield output yield output
return return
if just_started:
just_started = False
if not _continue:
output['internal'].append(['', ''])
output['visible'].append(['', ''])
if _continue: if _continue:
output['internal'][-1] = [text, last_reply[0] + reply] output['internal'][-1] = [text, last_reply[0] + reply]
output['visible'][-1] = [visible_text, last_reply[1] + visible_reply] output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]