diff --git a/README.md b/README.md index 2ab0336c..0b2b3297 100644 --- a/README.md +++ b/README.md @@ -150,3 +150,4 @@ Pull requests, suggestions, and issue reports are welcome. - NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets - Pygmalion preset: https://github.com/PygmalionAI/gradio-ui/blob/master/src/gradio_ui.py - Verbose preset: Anonymous 4chan user. +- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui diff --git a/html_generator.py b/modules/html_generator.py similarity index 100% rename from html_generator.py rename to modules/html_generator.py diff --git a/modules/ui.py b/modules/ui.py new file mode 100644 index 00000000..0630ebf1 --- /dev/null +++ b/modules/ui.py @@ -0,0 +1,30 @@ +import gradio as gr + +refresh_symbol = '\U0001f504' # 🔄 + +class ToolButton(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(variant="tool", **kwargs) + + def get_block_name(self): + return "button" + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button diff --git a/server.py b/server.py index 67320377..5ac1d20b 100644 --- a/server.py +++ b/server.py @@ -1,18 +1,19 @@ import re +import gc import time import glob -from sys import exit import torch import argparse import json +from sys import exit from pathlib import Path import gradio as gr -import transformers -from html_generator import * -from transformers import AutoTokenizer, AutoModelForCausalLM import warnings -import gc from tqdm import tqdm +import transformers +from transformers import AutoTokenizer, AutoModelForCausalLM +from modules.html_generator import * +from modules.ui import * transformers.logging.set_verbosity_error() @@ -36,9 +37,18 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T args = parser.parse_args() loaded_preset = None -available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower) -available_presets = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) -available_characters = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower) +def get_available_models(): + return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower) + +def get_available_presets(): + return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) + +def get_available_characters(): + return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower) + +available_models = get_available_models() +available_presets = get_available_presets() +available_characters = get_available_characters() settings = { 'max_new_tokens': 200, @@ -227,7 +237,7 @@ else: default_text = settings['prompt'] description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n" -css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}" +css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}" if args.chat or args.cai_chat: history = [] character = None @@ -413,24 +423,30 @@ if args.chat or args.cai_chat: with gr.Row(): with gr.Column(): length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) - model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') + with gr.Row(): + model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') + create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button") with gr.Column(): history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size (0 for no limit)', value=settings['history_size']) - preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset') + with gr.Row(): + preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset') + create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button") name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name') name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name') context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context') with gr.Row(): - character_menu = gr.Dropdown(choices=["None"]+available_characters, value="None", label='Character') + character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character') + create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button") + with gr.Row(): check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') with gr.Row(): with gr.Column(): - gr.Markdown("Upload chat history") + gr.Markdown("Upload chat history", elem_id="upload-label") upload = gr.File(type='binary') with gr.Column(): - gr.Markdown("Download chat history") + gr.Markdown("Download chat history", elem_id="download-label") save_btn = gr.Button(value="Click me") download = gr.File() @@ -473,9 +489,13 @@ elif args.notebook: length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) with gr.Row(): with gr.Column(): - model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') + with gr.Row(): + model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') + create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button") with gr.Column(): - preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') + with gr.Row(): + preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') + create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button") gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen") gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream) @@ -488,8 +508,12 @@ else: with gr.Column(): textbox = gr.Textbox(value=default_text, lines=15, label='Input') length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) - preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') - model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') + with gr.Row(): + preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') + create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button") + with gr.Row(): + model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') + create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button") btn = gr.Button("Generate") with gr.Row(): with gr.Column(): diff --git a/torch-dumps/place-your-pt-models-here.txt b/torch-dumps/place-your-pt-models-here.txt new file mode 100644 index 00000000..e69de29b