Add refresh buttons for the model/preset/character menus

This commit is contained in:
oobabooga 2023-01-22 00:02:46 -03:00
parent bc664ecf3b
commit 434d4b128c
5 changed files with 73 additions and 18 deletions

View File

@ -150,3 +150,4 @@ Pull requests, suggestions, and issue reports are welcome.
- NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
- Pygmalion preset: https://github.com/PygmalionAI/gradio-ui/blob/master/src/gradio_ui.py
- Verbose preset: Anonymous 4chan user.
- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui

30
modules/ui.py Normal file
View File

@ -0,0 +1,30 @@
import gradio as gr
refresh_symbol = '\U0001f504' # 🔄
class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""
def __init__(self, **kwargs):
super().__init__(variant="tool", **kwargs)
def get_block_name(self):
return "button"
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items():
setattr(refresh_component, k, v)
return gr.update(**(args or {}))
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[refresh_component]
)
return refresh_button

View File

@ -1,18 +1,19 @@
import re
import gc
import time
import glob
from sys import exit
import torch
import argparse
import json
from sys import exit
from pathlib import Path
import gradio as gr
import transformers
from html_generator import *
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
import gc
from tqdm import tqdm
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from modules.html_generator import *
from modules.ui import *
transformers.logging.set_verbosity_error()
@ -36,9 +37,18 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T
args = parser.parse_args()
loaded_preset = None
available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
available_presets = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
available_characters = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
def get_available_models():
return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
def get_available_characters():
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()
settings = {
'max_new_tokens': 200,
@ -227,7 +237,7 @@ else:
default_text = settings['prompt']
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}"
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
if args.chat or args.cai_chat:
history = []
character = None
@ -413,24 +423,30 @@ if args.chat or args.cai_chat:
with gr.Row():
with gr.Column():
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
with gr.Row():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
with gr.Column():
history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size (0 for no limit)', value=settings['history_size'])
with gr.Row():
preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
with gr.Row():
character_menu = gr.Dropdown(choices=["None"]+available_characters, value="None", label='Character')
character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
with gr.Row():
check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
with gr.Row():
with gr.Column():
gr.Markdown("Upload chat history")
gr.Markdown("Upload chat history", elem_id="upload-label")
upload = gr.File(type='binary')
with gr.Column():
gr.Markdown("Download chat history")
gr.Markdown("Download chat history", elem_id="download-label")
save_btn = gr.Button(value="Click me")
download = gr.File()
@ -473,9 +489,13 @@ elif args.notebook:
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
with gr.Row():
with gr.Column():
with gr.Row():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
with gr.Column():
with gr.Row():
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
@ -488,8 +508,12 @@ else:
with gr.Column():
textbox = gr.Textbox(value=default_text, lines=15, label='Input')
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
with gr.Row():
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
with gr.Row():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
btn = gr.Button("Generate")
with gr.Row():
with gr.Column():