mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Add a Continue button to chat mode
This commit is contained in:
parent
170e0c05c4
commit
d29f4624e9
@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||||||
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
||||||
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
||||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||||
|
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||||
rows = [f"{context.strip()}\n"]
|
rows = [f"{context.strip()}\n"]
|
||||||
|
|
||||||
@ -39,7 +40,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||||||
|
|
||||||
i = len(shared.history['internal']) - 1
|
i = len(shared.history['internal']) - 1
|
||||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
||||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
if _continue and i == len(shared.history['internal']) - 1:
|
||||||
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||||
|
else:
|
||||||
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
||||||
string = shared.history['internal'][i][0]
|
string = shared.history['internal'][i][0]
|
||||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||||
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
|
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
|
||||||
@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||||||
if impersonate:
|
if impersonate:
|
||||||
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||||
limit = 2
|
limit = 2
|
||||||
|
elif _continue:
|
||||||
|
limit = 3
|
||||||
else:
|
else:
|
||||||
# Adding the user message
|
# Adding the user message
|
||||||
user_input = fix_newlines(user_input)
|
user_input = fix_newlines(user_input)
|
||||||
@ -56,12 +62,12 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||||||
|
|
||||||
# Adding the Character prefix
|
# Adding the Character prefix
|
||||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||||
|
|
||||||
limit = 3
|
limit = 3
|
||||||
|
|
||||||
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
|
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
|
||||||
rows.pop(1)
|
rows.pop(1)
|
||||||
prompt = ''.join(rows)
|
prompt = ''.join(rows)
|
||||||
|
|
||||||
if also_return_rows:
|
if also_return_rows:
|
||||||
return prompt, rows
|
return prompt, rows
|
||||||
else:
|
else:
|
||||||
@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||||||
return reply, next_character_found
|
return reply, next_character_found
|
||||||
|
|
||||||
|
|
||||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
|
||||||
if mode == 'instruct':
|
if mode == 'instruct':
|
||||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||||
else:
|
else:
|
||||||
@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
|
|
||||||
# Defining some variables
|
# Defining some variables
|
||||||
cumulative_reply = ''
|
cumulative_reply = ''
|
||||||
|
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||||
just_started = True
|
just_started = True
|
||||||
name1_original = name1
|
name1_original = name1
|
||||||
visible_text = custom_generate_chat_prompt = None
|
visible_text = custom_generate_chat_prompt = None
|
||||||
@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
|
|
||||||
if visible_text is None:
|
if visible_text is None:
|
||||||
visible_text = text
|
visible_text = text
|
||||||
text = apply_extensions(text, "input")
|
if not _continue:
|
||||||
|
text = apply_extensions(text, "input")
|
||||||
|
|
||||||
# Generating the prompt
|
# Generating the prompt
|
||||||
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
|
kwargs = {
|
||||||
|
'end_of_turn': end_of_turn,
|
||||||
|
'is_instruct': mode == 'instruct',
|
||||||
|
'_continue': _continue
|
||||||
|
}
|
||||||
if custom_generate_chat_prompt is None:
|
if custom_generate_chat_prompt is None:
|
||||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||||
else:
|
else:
|
||||||
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||||
|
|
||||||
# Yield *Is typing...*
|
# Yield *Is typing...*
|
||||||
if not regenerate:
|
if not any((regenerate, _continue)):
|
||||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
@ -154,11 +166,16 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
return shared.history['visible']
|
return shared.history['visible']
|
||||||
if just_started:
|
if just_started:
|
||||||
just_started = False
|
just_started = False
|
||||||
shared.history['internal'].append(['', ''])
|
if not _continue:
|
||||||
shared.history['visible'].append(['', ''])
|
shared.history['internal'].append(['', ''])
|
||||||
|
shared.history['visible'].append(['', ''])
|
||||||
|
|
||||||
shared.history['internal'][-1] = [text, reply]
|
if _continue:
|
||||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
shared.history['internal'][-1] = [text, f'{last_reply[0]} {reply}']
|
||||||
|
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]} {visible_reply}']
|
||||||
|
else:
|
||||||
|
shared.history['internal'][-1] = [text, reply]
|
||||||
|
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||||
if not shared.args.no_stream:
|
if not shared.args.no_stream:
|
||||||
yield shared.history['visible']
|
yield shared.history['visible']
|
||||||
if next_character_found:
|
if next_character_found:
|
||||||
@ -220,6 +237,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
|
|||||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
|
def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||||
|
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||||
|
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
else:
|
||||||
|
# Yield ' ...'
|
||||||
|
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
|
||||||
|
for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
|
||||||
|
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
def remove_last_message(name1, name2, mode):
|
def remove_last_message(name1, name2, mode):
|
||||||
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||||
last = shared.history['visible'].pop()
|
last = shared.history['visible'].pop()
|
||||||
|
@ -327,8 +327,9 @@ def create_interface():
|
|||||||
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
|
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
|
||||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
|
||||||
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||||
|
shared.gradio['Continue'] = gr.Button('Continue')
|
||||||
|
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
||||||
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
||||||
@ -411,7 +412,11 @@ def create_interface():
|
|||||||
|
|
||||||
gen_events.append(shared.gradio['Regenerate'].click(
|
gen_events.append(shared.gradio['Regenerate'].click(
|
||||||
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
|
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user