Add model settings to the Models tab

oobabooga committed 2023-04-12 17:09:56 -03:00
parent 4f7e88c043
commit 1566d8e344
3 changed files with 136 additions and 60 deletions

modules/shared.py

@@ -41,6 +41,7 @@ settings = {
     'truncation_length': 2048,
     'truncation_length_min': 0,
     'truncation_length_max': 4096,
+    'mode': 'cai-chat',
     'chat_prompt_size': 2048,
     'chat_prompt_size_min': 0,
     'chat_prompt_size_max': 2048,
@@ -115,9 +116,6 @@ parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quant
 parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
 parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
 parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models.')
-parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use --wbits instead.')
-parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
-parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')

 # FlexGen
 parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
@@ -144,7 +142,7 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
 args = parser.parse_args()

 # Deprecation warnings for parameters that have been renamed
-deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
+deprecated_dict = {}
 for k in deprecated_dict:
     if eval(f"args.{k}") != deprecated_dict[k][1]:
         print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")

server.py

@@ -5,6 +5,7 @@ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 import importlib
 import io
 import json
+import math
 import os
 import re
 import sys
@@ -15,6 +16,8 @@ from datetime import datetime
 from pathlib import Path

 import gradio as gr
+import psutil
+import torch
 from PIL import Image

 import modules.extensions as extensions_module
@@ -82,14 +85,16 @@ def get_available_loras():

 def load_model_wrapper(selected_model):
-    if selected_model != shared.model_name:
+    try:
+        yield f"Loading {selected_model}..."
         shared.model_name = selected_model

         unload_model()
         if selected_model != '':
             shared.model, shared.tokenizer = load_model(shared.model_name)

-    return selected_model
+        yield f"Successfully loaded {selected_model}"
+    except:
+        yield traceback.format_exc()


 def load_lora_wrapper(selected_lora):
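Rewriting load_model_wrapper as a generator is what lets the Models tab show "Loading ..." and then a success or error message in one Markdown component: Gradio streams each yielded string to the output as it arrives. A minimal sketch of the idea, assuming hypothetical component names (generator event handlers require the queue to be enabled):

    import time
    import gradio as gr

    def load_status(name):
        yield f"Loading {name}..."         # displayed immediately
        time.sleep(1)                      # stand-in for the actual model load
        yield f"Successfully loaded {name}"

    with gr.Blocks() as demo:
        menu = gr.Dropdown(choices=['model-a', 'model-b'], label='Model')
        status = gr.Markdown()
        menu.change(load_status, menu, status, show_progress=True)

    demo.queue().launch()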
@@ -203,31 +208,117 @@ def download_model_wrapper(repo_id):
         yield traceback.format_exc()


+def list_model_parameters():
+    return ['gpu_memory', 'cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'wbits', 'groupsize', 'model_type', 'pre_layer']
+
+
+# Update the command-line arguments based on the interface values
+def update_model_parameters(*args):
+    args = list(args)
+    elements = list_model_parameters()
+
+    for i, element in enumerate(elements):
+        if element in ['gpu_memory', 'cpu_memory'] and args[i] == 0:
+            args[i] = None
+        if element == 'wbits' and args[i] == 'None':
+            args[i] = 0
+        if element == 'groupsize' and args[i] == 'None':
+            args[i] = -1
+        if element == 'model_type' and args[i] == 'None':
+            args[i] = None
+        if element in ['wbits', 'groupsize', 'pre_layer']:
+            args[i] = int(args[i])
+        if element == 'gpu_memory' and args[i] is not None:
+            args[i] = [f"{args[i]}MiB"]
+        elif element == 'cpu_memory' and args[i] is not None:
+            args[i] = f"{args[i]}MiB"
+
+        #print(element, repr(eval(f"shared.args.{element}")), repr(args[i]))
+        #print(f"shared.args.{element} = args[i]")
+        exec(f"shared.args.{element} = args[i]")
+        #print()
+
+
 def create_model_menus():
+    # Finding the default values for the GPU and CPU memories
+    total_mem = math.floor(torch.cuda.get_device_properties(0).total_memory / (1024*1024))
+    total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024*1024))
+    if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
+        default_gpu_mem = re.sub('[a-zA-Z ]', '', shared.args.gpu_memory[0])
+    else:
+        default_gpu_mem = 0
+    if shared.args.cpu_memory is not None:
+        default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory)
+    else:
+        default_cpu_mem = 0
+
+    components = {}
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
-                ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
-        with gr.Column():
-            with gr.Row():
-                shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
-                ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
-
-    with gr.Row():
-        with gr.Column():
-            with gr.Row():
                 with gr.Column():
-                    shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA",
-                                                                    info="Enter Hugging Face username/model path, e.g: facebook/galactica-125m")
+                    with gr.Row():
+                        shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
+                        ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
                 with gr.Column():
-                    shared.gradio['download_button'] = gr.Button("Download")
-            shared.gradio['download_status'] = gr.Markdown()
-        with gr.Column():
-            pass
-
-    shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
+                    with gr.Row():
+                        shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
+                        ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
+        with gr.Column():
+            unload = gr.Button("Unload the model")
+            reload = gr.Button("Reload the model")
+
+    with gr.Row():
+        with gr.Column():
+            with gr.Box():
+                with gr.Row():
+                    with gr.Column():
+                        components['gpu_memory'] = gr.Slider(label="gpu-memory in MiB", maximum=total_mem, value=default_gpu_mem)
+                        components['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem, value=default_cpu_mem)
+                    with gr.Column():
+                        components['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
+                        components['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
+                        components['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu)
+                        components['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
+                        components['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
+        with gr.Column():
+            with gr.Box():
+                with gr.Row():
+                    with gr.Column():
+                        components['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
+                        components['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128], value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
+                    with gr.Column():
+                        components['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gpt-j"], value=shared.args.model_type or "None")
+                        components['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer)
+
+    with gr.Row():
+        with gr.Column():
+            shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter Hugging Face username/model path, e.g: facebook/galactica-125m")
+            shared.gradio['download_button'] = gr.Button("Download")
+        with gr.Column():
+            shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
+
+    shared.gradio['model_menu'].change(
+        update_model_parameters, [components[k] for k in list_model_parameters()], None).then(
+        load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
+
+    unload.click(
+        unload_model, None, None).then(
+        lambda: "Model unloaded", None, shared.gradio['model_status'])
+
+    reload.click(
+        unload_model, None, None).then(
+        update_model_parameters, [components[k] for k in list_model_parameters()], None).then(
+        load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
+
     shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
-    shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
+    shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)


 def create_settings_menus(default_preset):
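The wiring above leans on Gradio's .then() chaining: selecting a model first pushes every UI value into shared.args via update_model_parameters, and only then triggers the actual load, so the loader sees freshly synced settings; the reload button replays the same chain after an unload. A reduced sketch of the pattern, with placeholder callbacks standing in for the real ones:

    import gradio as gr

    def sync_settings(gpu_memory, load_in_8bit):
        # stand-in for update_model_parameters: copy UI state into globals
        print(f"gpu_memory={gpu_memory}, load_in_8bit={load_in_8bit}")

    def load(name):
        yield f"Loading {name}..."
        yield f"Successfully loaded {name}"

    with gr.Blocks() as demo:
        menu = gr.Dropdown(choices=['model-a', 'model-b'], label='Model')
        gpu = gr.Slider(label='gpu-memory in MiB', maximum=24000)
        bits8 = gr.Checkbox(label='load-in-8bit')
        status = gr.Markdown()

        # .then() runs its callback only after the previous step finishes,
        # so the load always sees the synced settings
        menu.change(sync_settings, [gpu, bits8], None).then(
            load, menu, status, show_progress=True)

    demo.queue().launch()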
@@ -333,20 +424,6 @@ else:
 # Default model
 if shared.args.model is not None:
     shared.model_name = shared.args.model
-else:
-    if len(available_models) == 0:
-        print('No models are available! Please download at least one.')
-        sys.exit(0)
-    elif len(available_models) == 1:
-        i = 0
-    else:
-        print('The following models are available:\n')
-        for i, model in enumerate(available_models):
-            print(f'{i+1}. {model}')
-        print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
-        i = int(input()) - 1
-        print()
-    shared.model_name = available_models[i]
-shared.model, shared.tokenizer = load_model(shared.model_name)
-if shared.args.lora:
-    add_lora_to_model(shared.args.lora)
+    shared.model, shared.tokenizer = load_model(shared.model_name)
+    if shared.args.lora:
+        add_lora_to_model(shared.args.lora)
@@ -372,12 +449,12 @@ def create_interface():
     shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
     shared.gradio['Chat input'] = gr.State()

-    with gr.Tab("Text generation", elem_id="main"):
+    with gr.Tab('Text generation', elem_id='main'):
         shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
         shared.gradio['textbox'] = gr.Textbox(label='Input')
         with gr.Row():
             shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
-            shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
+            shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
         with gr.Row():
             shared.gradio['Regenerate'] = gr.Button('Regenerate')
             shared.gradio['Continue'] = gr.Button('Continue')
@@ -389,24 +466,24 @@ def create_interface():
             shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
         with gr.Row():
             shared.gradio['Clear history'] = gr.Button('Clear history')
-            shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
+            shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant='stop', visible=False)
             shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
             shared.gradio['Remove last'] = gr.Button('Remove last')

-        shared.gradio["mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
-        shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.")
+        shared.gradio['mode'] = gr.Radio(choices=['cai-chat', 'chat', 'instruct'], value=shared.settings['mode'], label='Mode')
+        shared.gradio['Instruction templates'] = gr.Dropdown(choices=get_available_instruction_templates(), label='Instruction template', value='None', visible=False, info='Change this according to the model/LoRA that you are using.')

-    with gr.Tab("Character", elem_id="chat-settings"):
+    with gr.Tab('Character', elem_id='chat-settings'):
         with gr.Row():
             with gr.Column(scale=8):
                 shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
                 shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
                 shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
                 shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
-                shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string')
+                shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings['end_of_turn'], lines=1, label='End of turn string')
             with gr.Column(scale=1):
-                shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil")
-                shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
+                shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil')
+                shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None)
         with gr.Row():
             shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
             ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button')
@@ -422,7 +499,7 @@ def create_interface():
                 shared.gradio['download'] = gr.File()
                 shared.gradio['download_button'] = gr.Button(value='Click me')
             with gr.Tab('Upload character'):
-                gr.Markdown("# JSON format")
+                gr.Markdown('# JSON format')
                 with gr.Row():
                     with gr.Column():
                         gr.Markdown('1. Select the JSON file')
@@ -432,7 +509,7 @@ def create_interface():
                         shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
                     shared.gradio['Upload character'] = gr.Button(value='Submit')

-                gr.Markdown("# TavernAI PNG format")
+                gr.Markdown('# TavernAI PNG format')
                 shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])

     with gr.Tab("Parameters", elem_id="parameters"):
@@ -648,7 +725,7 @@ def create_interface():
                 current_mode = mode
                 break
         cmd_list = vars(shared.args)
-        bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
+        bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes + list_model_parameters()]
         bool_active = [k for k in bool_list if vars(shared.args)[k]]

         gr.Markdown("*Experimental*")
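Excluding list_model_parameters() here keeps the new model options from appearing twice, since this checklist is built by introspecting the argparse namespace for boolean flags. The introspection itself is just vars() on the Namespace, as in this small example:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--wbits', type=int, default=0)
    args = parser.parse_args(['--cpu'])

    cmd_list = vars(args)  # Namespace -> plain dict of flag names to values
    bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool]
    print(bool_list)       # ['cpu']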

settings-template.json

@@ -6,7 +6,7 @@
     "name1": "You",
     "name2": "Assistant",
     "context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
-    "greeting": "Hello there!",
+    "greeting": "",
     "end_of_turn": "",
     "custom_stopping_strings": "",
     "stop_at_newline": false,
@@ -15,6 +15,7 @@
     "truncation_length": 2048,
     "truncation_length_min": 0,
     "truncation_length_max": 4096,
+    "mode": "cai-chat",
     "chat_prompt_size": 2048,
     "chat_prompt_size_min": 0,
     "chat_prompt_size_max": 2048,