"character greeting" displayed and editable on the fly (#743)

* Add greeting field

* Add greeting field and make it interactive

* Minor changes

* Fix a bug

* Simplify clear_chat_log

* Change a label

* Minor change

* Simplifications

* Simplification

* Simplify loading the default character history

* Fix regression

---------

Co-authored-by: oobabooga
This commit is contained in:
OWKenobi 2023-04-03 17:16:15 +02:00 committed by GitHub
parent 8b1f20aa04
commit dcf61a8897
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 41 deletions

View File

@@ -35,7 +35,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length: while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n") rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
prev_user_input = shared.history['internal'][i][0] prev_user_input = shared.history['internal'][i][0]
if len(prev_user_input) > 0 and prev_user_input != '<|BEGIN-VISIBLE-CHAT|>': if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, f"{name1}: {prev_user_input.strip()}\n") rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
i -= 1 i -= 1
@@ -198,7 +198,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def remove_last_message(name1, name2): def remove_last_message(name1, name2):
if len(shared.history['visible']) > 0 and not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>': if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop() last = shared.history['visible'].pop()
shared.history['internal'].pop() shared.history['internal'].pop()
else: else:
@@ -228,21 +228,13 @@ def replace_last_reply(text, name1, name2):
def clear_html(): def clear_html():
return generate_chat_html([], "", "", shared.character) return generate_chat_html([], "", "", shared.character)
def clear_chat_log(name1, name2): def clear_chat_log(name1, name2, greeting):
if shared.character != 'None':
found = False
for i in range(len(shared.history['internal'])):
if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]:
shared.history['visible'] = [['', apply_extensions(shared.history['internal'][i][1], "output")]]
shared.history['internal'] = [shared.history['internal'][i]]
found = True
break
if not found:
shared.history['visible'] = [] shared.history['visible'] = []
shared.history['internal'] = [] shared.history['internal'] = []
else:
shared.history['internal'] = [] if greeting != '':
shared.history['visible'] = [] shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
return generate_chat_output(shared.history['visible'], name1, name2, shared.character) return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
@@ -287,11 +279,10 @@ def tokenize_dialogue(dialogue, name1, name2):
return history return history
def save_history(timestamp=True): def save_history(timestamp=True):
prefix = '' if shared.character == 'None' else f"{shared.character}_"
if timestamp: if timestamp:
fname = f"{prefix}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else: else:
fname = f"{prefix}persistent.json" fname = f"{shared.character}_persistent.json"
if not Path('logs').exists(): if not Path('logs').exists():
Path('logs').mkdir() Path('logs').mkdir()
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f: with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
@@ -322,14 +313,6 @@ def load_history(file, name1, name2):
shared.history['internal'] = tokenize_dialogue(file, name1, name2) shared.history['internal'] = tokenize_dialogue(file, name1, name2)
shared.history['visible'] = copy.deepcopy(shared.history['internal']) shared.history['visible'] = copy.deepcopy(shared.history['internal'])
def load_default_history(name1, name2):
shared.character = 'None'
if Path('logs/persistent.json').exists():
load_history(open(Path('logs/persistent.json'), 'rb').read(), name1, name2)
else:
shared.history['internal'] = []
shared.history['visible'] = []
def replace_character_names(text, name1, name2): def replace_character_names(text, name1, name2):
text = text.replace('{{user}}', name1).replace('{{char}}', name2) text = text.replace('{{user}}', name1).replace('{{char}}', name2)
return text.replace('<USER>', name1).replace('<BOT>', name2) return text.replace('<USER>', name1).replace('<BOT>', name2)
@@ -343,20 +326,24 @@ def build_pygmalion_style_context(data):
context = f"{context.strip()}\n<START>\n" context = f"{context.strip()}\n<START>\n"
return context return context
def load_character(_character, name1, name2): def load_character(character, name1, name2):
shared.character = character
shared.history['internal'] = [] shared.history['internal'] = []
shared.history['visible'] = [] shared.history['visible'] = []
if _character != 'None': greeting = ""
shared.character = _character
if character != 'None':
for extension in ["yml", "yaml", "json"]: for extension in ["yml", "yaml", "json"]:
filepath = Path(f'characters/{_character}.{extension}') filepath = Path(f'characters/{character}.{extension}')
if filepath.exists(): if filepath.exists():
break break
file_contents = open(filepath, 'r', encoding='utf-8').read() file_contents = open(filepath, 'r', encoding='utf-8').read()
data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents) data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
if 'your_name' in data and data['your_name'] != '':
name1 = data['your_name']
name2 = data['name'] if 'name' in data else data['char_name'] name2 = data['name'] if 'name' in data else data['char_name']
for field in ['context', 'greeting', 'example_dialogue', 'char_persona', 'char_greeting', 'world_scenario']: for field in ['context', 'greeting', 'example_dialogue', 'char_persona', 'char_greeting', 'world_scenario']:
if field in data: if field in data:
data[field] = replace_character_names(data[field], name1, name2) data[field] = replace_character_names(data[field], name1, name2)
@@ -371,20 +358,25 @@ def load_character(_character, name1, name2):
if 'example_dialogue' in data and data['example_dialogue'] != '': if 'example_dialogue' in data and data['example_dialogue'] != '':
context += f"{data['example_dialogue'].strip()}\n" context += f"{data['example_dialogue'].strip()}\n"
if greeting_field in data and len(data[greeting_field].strip()) > 0: if greeting_field in data and len(data[greeting_field].strip()) > 0:
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data[greeting_field]]] greeting = data[greeting_field]
shared.history['visible'] += [['', apply_extensions(data[greeting_field], "output")]]
else: else:
shared.character = 'None'
context = shared.settings['context'] context = shared.settings['context']
name2 = shared.settings['name2'] name2 = shared.settings['name2']
greeting = shared.settings['greeting']
if Path(f'logs/{shared.character}_persistent.json').exists(): if Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2) load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
elif greeting != "":
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
if shared.args.cai_chat: if shared.args.cai_chat:
return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character) return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else: else:
return name2, context, shared.history['visible'] return name1, name2, greeting, context, shared.history['visible']
def load_default_history(name1, name2):
load_character("None", name1, name2)
def upload_character(json_file, img, tavern=False): def upload_character(json_file, img, tavern=False):
json_file = json_file if type(json_file) == str else json_file.decode('utf-8') json_file = json_file if type(json_file) == str else json_file.decode('utf-8')

View File

@@ -31,6 +31,7 @@ settings = {
'name1': 'You', 'name1': 'You',
'name2': 'Assistant', 'name2': 'Assistant',
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.', 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
'greeting': 'Hello there!',
'stop_at_newline': False, 'stop_at_newline': False,
'chat_prompt_size': 2048, 'chat_prompt_size': 2048,
'chat_prompt_size_min': 0, 'chat_prompt_size_min': 0,

View File

@@ -317,8 +317,9 @@ def create_interface():
with gr.Tab("Character", elem_id="chat-settings"): with gr.Tab("Character", elem_id="chat-settings"):
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name') shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context') shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting')
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context')
with gr.Row(): with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@@ -381,7 +382,7 @@ def create_interface():
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display']) shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display'])
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
@@ -396,7 +397,7 @@ def create_interface():
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']]) shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']])
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], []) shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])