From 65d8a24a6df424e01412c71de463ae54fe5626ea Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 22:28:49 -0300 Subject: [PATCH] Show profile pictures in the Character tab --- README.md | 2 +- extensions/gallery/script.py | 5 ++- modules/chat.py | 59 ++++++++++++++++++++++++++---------- modules/html_generator.py | 15 +++------ modules/shared.py | 2 +- server.py | 24 +++++++++------ 6 files changed, 66 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 373f83f4..065a9a30 100644 --- a/README.md +++ b/README.md @@ -175,7 +175,7 @@ Optionally, you can use the following command-line flags: | `-h`, `--help` | show this help message and exit | | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | | `--chat` | Launch the web UI in chat mode.| -| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. | +| `--cai-chat` | Launch the web UI in chat mode with a style similar to the Character.AI website. | | `--model MODEL` | Name of the model to load by default. | | `--lora LORA` | Name of the LoRA to apply to the model by default. | | `--model-dir MODEL_DIR` | Path to directory with all the models | diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py index 034506d4..51ab6434 100644 --- a/extensions/gallery/script.py +++ b/extensions/gallery/script.py @@ -2,9 +2,8 @@ from pathlib import Path import gradio as gr -from modules.chat import load_character from modules.html_generator import get_image_cache -from modules.shared import gradio, settings +from modules.shared import gradio def generate_css(): @@ -64,7 +63,7 @@ def generate_html(): for file in sorted(Path("characters").glob("*")): if file.suffix in [".json", ".yml", ".yaml"]: character = file.stem - container_html = f'
' + container_html = '
' image_html = "
" for i in [ diff --git a/modules/chat.py b/modules/chat.py index cd8639c2..2a76bddd 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -17,9 +17,9 @@ from modules.text_generation import (encode, generate_reply, get_max_prompt_length) -def generate_chat_output(history, name1, name2, character): +def generate_chat_output(history, name1, name2): if shared.args.cai_chat: - return generate_chat_html(history, name1, name2, character) + return generate_chat_html(history, name1, name2) else: return history @@ -180,22 +180,22 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts): - yield generate_chat_html(history, name1, name2, shared.character) + yield generate_chat_html(history, name1, name2) def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: - yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) + yield generate_chat_output(shared.history['visible'], name1, name2) else: last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() # Yield '*Is typing...*' - yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character) + yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2) for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], history[-1][1]] else: shared.history['visible'][-1] = (last_visible[0], history[-1][1]) - yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) + yield generate_chat_output(shared.history['visible'], name1, name2) def remove_last_message(name1, name2): if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': @@ -205,7 +205,7 @@ def remove_last_message(name1, name2): last = ['', ''] if shared.args.cai_chat: - return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0] + return generate_chat_html(shared.history['visible'], name1, name2), last[0] else: return shared.history['visible'], last[0] @@ -223,10 +223,10 @@ def replace_last_reply(text, name1, name2): shared.history['visible'][-1] = (shared.history['visible'][-1][0], text) shared.history['internal'][-1][1] = apply_extensions(text, "input") - return generate_chat_output(shared.history['visible'], name1, name2, shared.character) + return generate_chat_output(shared.history['visible'], name1, name2) def clear_html(): - return generate_chat_html([], "", "", shared.character) + return generate_chat_html([], "", "") def clear_chat_log(name1, name2, greeting): shared.history['visible'] = [] @@ -236,10 +236,10 @@ def clear_chat_log(name1, name2, greeting): shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]] - return generate_chat_output(shared.history['visible'], name1, name2, shared.character) + return generate_chat_output(shared.history['visible'], name1, name2) def redraw_html(name1, name2): - return generate_chat_html(shared.history['visible'], name1, name2, shared.character) + return generate_chat_html(shared.history['visible'], name1, name2) def tokenize_dialogue(dialogue, name1, name2): history = [] @@ -326,13 +326,32 @@ def build_pygmalion_style_context(data): context = f"{context.strip()}\n\n" return context +def generate_pfp_cache(character): + cache_folder = Path("cache") + if not cache_folder.exists(): + cache_folder.mkdir() + + for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]: + if path.exists(): + img = Image.open(path) + img.thumbnail((200, 200)) + img.save(Path('cache/pfp_character.png'), format='PNG') + return img + return None + def load_character(character, name1, name2): shared.character = character shared.history['internal'] = [] shared.history['visible'] = [] greeting = "" + picture = None + + # Deleting the profile picture cache, if any + if Path("cache/pfp_character.png").exists(): + Path("cache/pfp_character.png").unlink() if character != 'None': + picture = generate_pfp_cache(character) for extension in ["yml", "yaml", "json"]: filepath = Path(f'characters/{character}.{extension}') if filepath.exists(): @@ -371,9 +390,9 @@ def load_character(character, name1, name2): shared.history['visible'] += [['', apply_extensions(greeting, "output")]] if shared.args.cai_chat: - return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character) + return name1, name2, picture, greeting, context, generate_chat_html(shared.history['visible'], name1, name2) else: - return name1, name2, greeting, context, shared.history['visible'] + return name1, name2, picture, greeting, context, shared.history['visible'] def load_default_history(name1, name2): load_character("None", name1, name2) @@ -405,6 +424,14 @@ def upload_tavern_character(img, name1, name2): return upload_character(json.dumps(_json), img, tavern=True) def upload_your_profile_picture(img): - img = Image.open(io.BytesIO(img)) - img.save(Path('img_me.png')) - print('Profile picture saved to "img_me.png"') + cache_folder = Path("cache") + if not cache_folder.exists(): + cache_folder.mkdir() + + if img == None: + if Path("cache/pfp_me.png").exists(): + Path("cache/pfp_me.png").unlink() + else: + img.thumbnail((200, 200)) + img.save(Path('cache/pfp_me.png')) + print('Profile picture saved to "cache/pfp_me.png"') diff --git a/modules/html_generator.py b/modules/html_generator.py index 48d2e02e..a6b969b8 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -6,6 +6,7 @@ This is a library for formatting text outputs as nice HTML. import os import re +import time from pathlib import Path import markdown @@ -110,18 +111,12 @@ def get_image_cache(path): return image_cache[path][1] -def load_html_image(paths): - for str_path in paths: - path = Path(str_path) - if path.exists(): - return f'' - return '' - -def generate_chat_html(history, name1, name2, character): +def generate_chat_html(history, name1, name2): output = f'
' - img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"]) - img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"]) + # The time.time() is to prevent the brower from caching the image + img_bot = f'' if Path("cache/pfp_character.png").exists() else '' + img_me = f'' if Path("cache/pfp_me.png").exists() else '' for i,_row in enumerate(history[::-1]): row = [convert_to_markdown(entry) for entry in _row] diff --git a/modules/shared.py b/modules/shared.py index 038e392a..6c183a81 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma # Basic settings parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.') -parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.') +parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.') parser.add_argument('--model', type=str, help='Name of the model to load by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") diff --git a/server.py b/server.py index 0a837c50..914448a0 100644 --- a/server.py +++ b/server.py @@ -8,6 +8,7 @@ from datetime import datetime from pathlib import Path import gradio as gr +from PIL import Image import modules.extensions as extensions_module from modules import chat, shared, training, ui @@ -296,7 +297,7 @@ def create_interface(): shared.gradio['Chat input'] = gr.State() with gr.Tab("Text generation", elem_id="main"): if shared.args.cai_chat: - shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character)) + shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'])) else: shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot") shared.gradio['textbox'] = gr.Textbox(label='Input') @@ -316,10 +317,15 @@ def create_interface(): shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) with gr.Tab("Character", elem_id="chat-settings"): - shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') - shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') - shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting') - shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context') + with gr.Row(): + with gr.Column(scale=8): + shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') + shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') + shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting') + shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context') + with gr.Column(scale=1): + shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil") + shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None) with gr.Row(): shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') @@ -347,8 +353,6 @@ def create_interface(): gr.Markdown("# TavernAI PNG format") shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image']) - with gr.Tab('Upload your profile picture'): - shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image']) with gr.Tab("Parameters", elem_id="parameters"): with gr.Box(): @@ -399,15 +403,15 @@ def create_interface(): shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']]) + shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'display']]) shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) - shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], []) + shared.gradio['your_picture'].change(chat.upload_your_profile_picture, shared.gradio['your_picture'], []) reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) - shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) + shared.gradio['your_picture'].change(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")