Add a Continue button to chat mode

This commit is contained in:
oobabooga 2023-04-09 20:04:16 -03:00
parent 170e0c05c4
commit d29f4624e9
2 changed files with 44 additions and 12 deletions

View File

@@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
     end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
     impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
+    _continue = kwargs['_continue'] if '_continue' in kwargs else False
     also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False

     rows = [f"{context.strip()}\n"]
@@ -39,6 +40,9 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     i = len(shared.history['internal']) - 1
     while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
+        if _continue and i == len(shared.history['internal']) - 1:
+            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
+        else:
             rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
         string = shared.history['internal'][i][0]
         if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
@@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     if impersonate:
         rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
         limit = 2
+    elif _continue:
+        limit = 3
     else:
         # Adding the user message
         user_input = fix_newlines(user_input)
@@ -56,12 +62,12 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
         # Adding the Character prefix
         rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
         limit = 3

     while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
         rows.pop(1)

     prompt = ''.join(rows)
     if also_return_rows:
         return prompt, rows
     else:
@@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
     return reply, next_character_found

-def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
+def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
     if mode == 'instruct':
         stopping_strings = [f"\n{name1}", f"\n{name2}"]
     else:
@@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
     # Defining some variables
     cumulative_reply = ''
+    last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
     just_started = True
     name1_original = name1
     visible_text = custom_generate_chat_prompt = None
@@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
     if visible_text is None:
         visible_text = text
-    text = apply_extensions(text, "input")
+    if not _continue:
+        text = apply_extensions(text, "input")

     # Generating the prompt
-    kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
+    kwargs = {
+        'end_of_turn': end_of_turn,
+        'is_instruct': mode == 'instruct',
+        '_continue': _continue
+    }
     if custom_generate_chat_prompt is None:
         prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
     else:
         prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)

     # Yield *Is typing...*
-    if not regenerate:
+    if not any((regenerate, _continue)):
         yield shared.history['visible'] + [[visible_text, shared.processing_message]]

     # Generate
@@ -154,9 +166,14 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
             return shared.history['visible']
         if just_started:
             just_started = False
+            if not _continue:
                 shared.history['internal'].append(['', ''])
                 shared.history['visible'].append(['', ''])

+        if _continue:
+            shared.history['internal'][-1] = [text, f'{last_reply[0]} {reply}']
+            shared.history['visible'][-1] = [visible_text, f'{last_reply[1]} {visible_reply}']
+        else:
             shared.history['internal'][-1] = [text, reply]
             shared.history['visible'][-1] = [visible_text, visible_reply]
         if not shared.args.no_stream:
@@ -220,6 +237,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
             yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)

+def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+    if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
+        yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
+    else:
+        # Yield ' ...'
+        yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
+        for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
+            yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
+
 def remove_last_message(name1, name2, mode):
     if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
         last = shared.history['visible'].pop()

View File

@@ -327,8 +327,9 @@ def create_interface():
     shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
     shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
     with gr.Row():
-        shared.gradio['Impersonate'] = gr.Button('Impersonate')
         shared.gradio['Regenerate'] = gr.Button('Regenerate')
+        shared.gradio['Continue'] = gr.Button('Continue')
+        shared.gradio['Impersonate'] = gr.Button('Impersonate')
     with gr.Row():
         shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
         shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
@@ -411,7 +412,11 @@ def create_interface():
     gen_events.append(shared.gradio['Regenerate'].click(
         chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
-        lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
+        lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
+    )
+
+    gen_events.append(shared.gradio['Continue'].click(
+        chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
         lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
     )