From 7c0a17962d1404d181d69319bdcb6d60f0dc3c39 Mon Sep 17 00:00:00 2001 From: Lounger <4087076+TheLounger@users.noreply.github.com> Date: Mon, 4 Dec 2023 02:45:50 +0100 Subject: [PATCH] Gallery improvements (#4789) --- extensions/gallery/script.js | 7 +++++++ extensions/gallery/script.py | 34 +++++++++++++++++++++++++++++----- modules/chat.py | 8 ++++---- modules/html_generator.py | 10 +++++----- modules/shared.py | 2 ++ modules/ui_chat.py | 4 ++-- settings-template.yaml | 2 ++ 7 files changed, 51 insertions(+), 16 deletions(-) diff --git a/extensions/gallery/script.js b/extensions/gallery/script.js index 4ff23afc..9717aa67 100644 --- a/extensions/gallery/script.js +++ b/extensions/gallery/script.js @@ -5,6 +5,13 @@ let extensions_block = document.getElementById('extensions'); let extensions_block_size = extensions_block.childNodes.length; let gallery_only = (extensions_block_size == 5); +function gotoFirstPage() { + const firstPageButton = gallery_element.querySelector('.paginate > button'); + if (firstPageButton) { + firstPageButton.click(); + } +} + document.querySelector('.header_bar').addEventListener('click', function(event) { if (event.target.tagName === 'BUTTON') { const buttonText = event.target.textContent.trim(); diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py index efe96ba9..d520abb5 100644 --- a/extensions/gallery/script.py +++ b/extensions/gallery/script.py @@ -3,11 +3,17 @@ from pathlib import Path import gradio as gr from modules.html_generator import get_image_cache -from modules.shared import gradio +from modules.shared import gradio, settings + +cards = [] def generate_css(): css = """ + .highlighted-border { + border-color: rgb(249, 115, 22) !important; + } + .character-gallery > .gallery { margin: 1rem 0; display: grid !important; @@ -58,6 +64,7 @@ def generate_css(): def generate_html(): + global cards cards = [] # Iterate through files in image folder for file in sorted(Path("characters").glob("*")): @@ -78,6 +85,14 @@ def generate_html(): return cards +def filter_cards(filter_str=''): + if filter_str == '': + return cards + + filter_upper = filter_str.upper() + return [k for k in cards if filter_upper in k[1].upper()] + + def select_character(evt: gr.SelectData): return (evt.value[1]) @@ -88,16 +103,25 @@ def custom_js(): def ui(): - with gr.Accordion("Character gallery", open=False, elem_id='gallery-extension'): - update = gr.Button("Refresh") + with gr.Accordion("Character gallery", open=settings["gallery-open"], elem_id='gallery-extension'): gr.HTML(value="") + with gr.Row(): + filter_box = gr.Textbox(label='', placeholder='Filter', lines=1, max_lines=1, container=False, elem_id='gallery-filter-box') + gr.ClearButton(filter_box, value='🗑️', elem_classes='refresh-button') + update = gr.Button("Refresh", elem_classes='refresh-button') gallery = gr.Dataset( components=[gr.HTML(visible=False)], label="", samples=generate_html(), elem_classes=["character-gallery"], - samples_per_page=50 + samples_per_page=settings["gallery-items_per_page"] ) - update.click(generate_html, [], gallery) + filter_box.change(lambda: None, None, None, _js=f'() => {{{custom_js()}; gotoFirstPage()}}').success( + filter_cards, filter_box, gallery).then( + lambda x: gr.update(elem_classes='highlighted-border' if x != '' else ''), filter_box, filter_box, show_progress=False) + + update.click(generate_html, [], None).success( + filter_cards, filter_box, gallery) + gallery.select(select_character, None, gradio['character_menu']) diff --git a/modules/chat.py b/modules/chat.py index 22b5bf9a..e126a428 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -281,7 +281,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess def impersonate_wrapper(text, state): - static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style']) + static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) if shared.model_name == 'None' or shared.model is None: logger.error("No model is loaded! Select one in the Model tab.") @@ -340,7 +340,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): send_dummy_reply(state['start_with'], state) for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True)): - yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style']), history + yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history def remove_last_message(history): @@ -390,8 +390,8 @@ def send_dummy_reply(text, state): return history -def redraw_html(history, name1, name2, mode, style, reset_cache=False): - return chat_html_wrapper(history, name1, name2, mode, style, reset_cache=reset_cache) +def redraw_html(history, name1, name2, mode, style, character, reset_cache=False): + return chat_html_wrapper(history, name1, name2, mode, style, character, reset_cache=reset_cache) def start_new_chat(state): diff --git a/modules/html_generator.py b/modules/html_generator.py index 2a6509b3..789cf639 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -221,11 +221,11 @@ def generate_instruct_html(history): return output -def generate_cai_chat_html(history, name1, name2, style, reset_cache=False): +def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=False): output = f'