Add refresh buttons for the model/preset/character menus

This commit is contained in:
oobabooga 2023-01-22 00:02:46 -03:00
parent bc664ecf3b
commit 434d4b128c
5 changed files with 73 additions and 18 deletions

View File

@ -150,3 +150,4 @@ Pull requests, suggestions, and issue reports are welcome.
- NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets - NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
- Pygmalion preset: https://github.com/PygmalionAI/gradio-ui/blob/master/src/gradio_ui.py - Pygmalion preset: https://github.com/PygmalionAI/gradio-ui/blob/master/src/gradio_ui.py
- Verbose preset: Anonymous 4chan user. - Verbose preset: Anonymous 4chan user.
- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui

30
modules/ui.py Normal file
View File

@ -0,0 +1,30 @@
import gradio as gr
refresh_symbol = '\U0001f504' # 🔄
class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""
def __init__(self, **kwargs):
super().__init__(variant="tool", **kwargs)
def get_block_name(self):
return "button"
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items():
setattr(refresh_component, k, v)
return gr.update(**(args or {}))
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[refresh_component]
)
return refresh_button

View File

@ -1,18 +1,19 @@
import re import re
import gc
import time import time
import glob import glob
from sys import exit
import torch import torch
import argparse import argparse
import json import json
from sys import exit
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import transformers
from html_generator import *
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings import warnings
import gc
from tqdm import tqdm from tqdm import tqdm
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from modules.html_generator import *
from modules.ui import *
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
@ -36,9 +37,18 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T
args = parser.parse_args() args = parser.parse_args()
loaded_preset = None loaded_preset = None
available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower) def get_available_models():
available_presets = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
available_characters = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
def get_available_characters():
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()
settings = { settings = {
'max_new_tokens': 200, 'max_new_tokens': 200,
@ -227,7 +237,7 @@ else:
default_text = settings['prompt'] default_text = settings['prompt']
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n" description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}" css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
if args.chat or args.cai_chat: if args.chat or args.cai_chat:
history = [] history = []
character = None character = None
@ -413,24 +423,30 @@ if args.chat or args.cai_chat:
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') with gr.Row():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
with gr.Column(): with gr.Column():
history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size (0 for no limit)', value=settings['history_size']) history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size (0 for no limit)', value=settings['history_size'])
preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset') with gr.Row():
preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name') name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name') name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context') context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
with gr.Row(): with gr.Row():
character_menu = gr.Dropdown(choices=["None"]+available_characters, value="None", label='Character') character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
with gr.Row(): with gr.Row():
check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
gr.Markdown("Upload chat history") gr.Markdown("Upload chat history", elem_id="upload-label")
upload = gr.File(type='binary') upload = gr.File(type='binary')
with gr.Column(): with gr.Column():
gr.Markdown("Download chat history") gr.Markdown("Download chat history", elem_id="download-label")
save_btn = gr.Button(value="Click me") save_btn = gr.Button(value="Click me")
download = gr.File() download = gr.File()
@ -473,9 +489,13 @@ elif args.notebook:
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') with gr.Row():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
with gr.Column(): with gr.Column():
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') with gr.Row():
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen") gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream) gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
@ -488,8 +508,12 @@ else:
with gr.Column(): with gr.Column():
textbox = gr.Textbox(value=default_text, lines=15, label='Input') textbox = gr.Textbox(value=default_text, lines=15, label='Input')
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') with gr.Row():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
with gr.Row():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
btn = gr.Button("Generate") btn = gr.Button("Generate")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():