From dcf61a8897bde4a04518ec17258bc4912f5229b9 Mon Sep 17 00:00:00 2001 From: OWKenobi Date: Mon, 3 Apr 2023 17:16:15 +0200 Subject: [PATCH] "character greeting" displayed and editable on the fly (#743) * Add greetings field * add greeting field and make it interactive * Minor changes * Fix a bug * Simplify clear_chat_log * Change a label * Minor change * Simplifications * Simplification * Simplify loading the default character history * Fix regression --------- Co-authored-by: oobabooga --- modules/chat.py | 66 +++++++++++++++++++++-------------------------- modules/shared.py | 1 + server.py | 9 ++++--- 3 files changed, 35 insertions(+), 41 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 0b010b4e..cd8639c2 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -35,7 +35,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length: rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n") prev_user_input = shared.history['internal'][i][0] - if len(prev_user_input) > 0 and prev_user_input != '<|BEGIN-VISIBLE-CHAT|>': + if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']: rows.insert(1, f"{name1}: {prev_user_input.strip()}\n") i -= 1 @@ -198,7 +198,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) def remove_last_message(name1, name2): - if len(shared.history['visible']) > 0 and not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>': + if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': last = shared.history['visible'].pop() shared.history['internal'].pop() else: @@ -228,21 +228,13 @@ def replace_last_reply(text, name1, name2): def clear_html(): return generate_chat_html([], "", "", shared.character) -def clear_chat_log(name1, name2): - if shared.character != 'None': - found = False - for i in range(len(shared.history['internal'])): - if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]: - shared.history['visible'] = [['', apply_extensions(shared.history['internal'][i][1], "output")]] - shared.history['internal'] = [shared.history['internal'][i]] - found = True - break - if not found: - shared.history['visible'] = [] - shared.history['internal'] = [] - else: - shared.history['internal'] = [] - shared.history['visible'] = [] +def clear_chat_log(name1, name2, greeting): + shared.history['visible'] = [] + shared.history['internal'] = [] + + if greeting != '': + shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] + shared.history['visible'] += [['', apply_extensions(greeting, "output")]] return generate_chat_output(shared.history['visible'], name1, name2, shared.character) @@ -287,11 +279,10 @@ def tokenize_dialogue(dialogue, name1, name2): return history def save_history(timestamp=True): - prefix = '' if shared.character == 'None' else f"{shared.character}_" if timestamp: - fname = f"{prefix}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" + fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" else: - fname = f"{prefix}persistent.json" + fname = f"{shared.character}_persistent.json" if not Path('logs').exists(): Path('logs').mkdir() with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f: @@ -322,14 +313,6 @@ def load_history(file, name1, name2): shared.history['internal'] = tokenize_dialogue(file, name1, name2) shared.history['visible'] = copy.deepcopy(shared.history['internal']) -def load_default_history(name1, name2): - shared.character = 'None' - if Path('logs/persistent.json').exists(): - load_history(open(Path('logs/persistent.json'), 'rb').read(), name1, name2) - else: - shared.history['internal'] = [] - shared.history['visible'] = [] - def replace_character_names(text, name1, name2): text = text.replace('{{user}}', name1).replace('{{char}}', name2) return text.replace('', name1).replace('', name2) @@ -343,20 +326,24 @@ def build_pygmalion_style_context(data): context = f"{context.strip()}\n\n" return context -def load_character(_character, name1, name2): +def load_character(character, name1, name2): + shared.character = character shared.history['internal'] = [] shared.history['visible'] = [] - if _character != 'None': - shared.character = _character + greeting = "" + if character != 'None': for extension in ["yml", "yaml", "json"]: - filepath = Path(f'characters/{_character}.{extension}') + filepath = Path(f'characters/{character}.{extension}') if filepath.exists(): break file_contents = open(filepath, 'r', encoding='utf-8').read() data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents) + if 'your_name' in data and data['your_name'] != '': + name1 = data['your_name'] name2 = data['name'] if 'name' in data else data['char_name'] + for field in ['context', 'greeting', 'example_dialogue', 'char_persona', 'char_greeting', 'world_scenario']: if field in data: data[field] = replace_character_names(data[field], name1, name2) @@ -371,20 +358,25 @@ def load_character(_character, name1, name2): if 'example_dialogue' in data and data['example_dialogue'] != '': context += f"{data['example_dialogue'].strip()}\n" if greeting_field in data and len(data[greeting_field].strip()) > 0: - shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data[greeting_field]]] - shared.history['visible'] += [['', apply_extensions(data[greeting_field], "output")]] + greeting = data[greeting_field] else: - shared.character = 'None' context = shared.settings['context'] name2 = shared.settings['name2'] + greeting = shared.settings['greeting'] if Path(f'logs/{shared.character}_persistent.json').exists(): load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2) + elif greeting != "": + shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] + shared.history['visible'] += [['', apply_extensions(greeting, "output")]] if shared.args.cai_chat: - return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character) + return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character) else: - return name2, context, shared.history['visible'] + return name1, name2, greeting, context, shared.history['visible'] + +def load_default_history(name1, name2): + load_character("None", name1, name2) def upload_character(json_file, img, tavern=False): json_file = json_file if type(json_file) == str else json_file.decode('utf-8') diff --git a/modules/shared.py b/modules/shared.py index c4225586..4e83c53e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -31,6 +31,7 @@ settings = { 'name1': 'You', 'name2': 'Assistant', 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.', + 'greeting': 'Hello there!', 'stop_at_newline': False, 'chat_prompt_size': 2048, 'chat_prompt_size_min': 0, diff --git a/server.py b/server.py index 2755e892..04650971 100644 --- a/server.py +++ b/server.py @@ -317,8 +317,9 @@ def create_interface(): with gr.Tab("Character", elem_id="chat-settings"): shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') - shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name') - shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context') + shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character''s name') + shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting') + shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context') with gr.Row(): shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') @@ -381,7 +382,7 @@ def create_interface(): clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) - shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display']) + shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display']) shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) @@ -396,7 +397,7 @@ def create_interface(): shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']]) + shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']]) shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])