mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-12 21:37:35 +01:00
Refactor everything (#3481)
This commit is contained in:
parent
d4b851bdc8
commit
65aa11890f
@ -3,7 +3,6 @@ import copy
|
||||
import functools
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
@ -64,7 +64,7 @@ class LlamacppHF(PreTrainedModel):
|
||||
else:
|
||||
self.model.eval([seq[-1]])
|
||||
|
||||
logits = torch.tensor(self.model.scores[self.model.n_tokens-1, :]).view(1, 1, -1).to(kwargs['input_ids'].device)
|
||||
logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(kwargs['input_ids'].device)
|
||||
else:
|
||||
self.model.reset()
|
||||
self.model.eval(seq)
|
||||
@ -112,7 +112,7 @@ class LlamacppHF(PreTrainedModel):
|
||||
'use_mlock': shared.args.mlock,
|
||||
'low_vram': shared.args.low_vram,
|
||||
'n_gpu_layers': shared.args.n_gpu_layers,
|
||||
'rope_freq_base': 10000 * shared.args.alpha_value ** (64/63.),
|
||||
'rope_freq_base': 10000 * shared.args.alpha_value ** (64 / 63.),
|
||||
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
|
||||
'n_gqa': shared.args.n_gqa or None,
|
||||
'rms_norm_eps': shared.args.rms_norm_eps or None,
|
||||
|
@ -65,7 +65,7 @@ class LlamaCppModel:
|
||||
'use_mlock': shared.args.mlock,
|
||||
'low_vram': shared.args.low_vram,
|
||||
'n_gpu_layers': shared.args.n_gpu_layers,
|
||||
'rope_freq_base': 10000 * shared.args.alpha_value ** (64/63.),
|
||||
'rope_freq_base': 10000 * shared.args.alpha_value ** (64 / 63.),
|
||||
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
|
||||
'n_gqa': shared.args.n_gqa or None,
|
||||
'rms_norm_eps': shared.args.rms_norm_eps or None,
|
||||
|
@ -1,9 +1,9 @@
|
||||
import gc
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@ -14,7 +14,7 @@ from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
BitsAndBytesConfig
|
||||
)
|
||||
|
||||
import modules.shared as shared
|
||||
|
@ -26,9 +26,9 @@ def infer_loader(model_name):
|
||||
loader = 'AutoGPTQ'
|
||||
elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
|
||||
loader = 'llama.cpp'
|
||||
elif re.match('.*ggml.*\.bin', model_name.lower()):
|
||||
elif re.match(r'.*ggml.*\.bin', model_name.lower()):
|
||||
loader = 'llama.cpp'
|
||||
elif re.match('.*rwkv.*\.pth', model_name.lower()):
|
||||
elif re.match(r'.*rwkv.*\.pth', model_name.lower()):
|
||||
loader = 'RWKV'
|
||||
else:
|
||||
loader = 'Transformers'
|
||||
|
51
modules/prompts.py
Normal file
51
modules/prompts.py
Normal file
@ -0,0 +1,51 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from modules import utils
|
||||
from modules.text_generation import get_encoded_length
|
||||
|
||||
|
||||
def load_prompt(fname):
|
||||
if fname in ['None', '']:
|
||||
return ''
|
||||
elif fname.startswith('Instruct-'):
|
||||
fname = re.sub('^Instruct-', '', fname)
|
||||
file_path = Path(f'characters/instruction-following/{fname}.yaml')
|
||||
if not file_path.exists():
|
||||
return ''
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
output = ''
|
||||
if 'context' in data:
|
||||
output += data['context']
|
||||
|
||||
replacements = {
|
||||
'<|user|>': data['user'],
|
||||
'<|bot|>': data['bot'],
|
||||
'<|user-message|>': 'Input',
|
||||
}
|
||||
|
||||
output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements)
|
||||
return output.rstrip(' ')
|
||||
else:
|
||||
file_path = Path(f'prompts/{fname}.txt')
|
||||
if not file_path.exists():
|
||||
return ''
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
text = f.read()
|
||||
if text[-1] == '\n':
|
||||
text = text[:-1]
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def count_tokens(text):
|
||||
try:
|
||||
tokens = get_encoded_length(text)
|
||||
return f'{tokens} tokens in the input.'
|
||||
except:
|
||||
return 'Couldn\'t count the number of tokens. Is a tokenizer loaded?'
|
@ -31,8 +31,62 @@ def generate_reply(*args, **kwargs):
|
||||
shared.generation_lock.release()
|
||||
|
||||
|
||||
def get_max_prompt_length(state):
|
||||
return state['truncation_length'] - state['max_new_tokens']
|
||||
def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
||||
|
||||
# Find the appropriate generation function
|
||||
generate_func = apply_extensions('custom_generate_reply')
|
||||
if generate_func is None:
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logger.error("No model is loaded! Select one in the Model tab.")
|
||||
yield ''
|
||||
return
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
|
||||
generate_func = generate_reply_custom
|
||||
else:
|
||||
generate_func = generate_reply_HF
|
||||
|
||||
# Prepare the input
|
||||
original_question = question
|
||||
if not is_chat:
|
||||
state = apply_extensions('state', state)
|
||||
question = apply_extensions('input', question, state)
|
||||
|
||||
# Find the stopping strings
|
||||
all_stop_strings = []
|
||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||
if type(st) is list and len(st) > 0:
|
||||
all_stop_strings += st
|
||||
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{question}\n--------------------\n')
|
||||
|
||||
shared.stop_everything = False
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
last_update = -1
|
||||
reply = ''
|
||||
is_stream = state['stream']
|
||||
if len(all_stop_strings) > 0 and not state['stream']:
|
||||
state = copy.deepcopy(state)
|
||||
state['stream'] = True
|
||||
|
||||
# Generate
|
||||
for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
|
||||
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
|
||||
if is_stream:
|
||||
cur_time = time.time()
|
||||
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
|
||||
last_update = cur_time
|
||||
yield reply
|
||||
|
||||
if stop_found:
|
||||
break
|
||||
|
||||
if not is_chat:
|
||||
reply = apply_extensions('output', reply, state)
|
||||
|
||||
yield reply
|
||||
|
||||
|
||||
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
||||
@ -61,6 +115,10 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
||||
return input_ids.cuda()
|
||||
|
||||
|
||||
def decode(output_ids, skip_special_tokens=True):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||
|
||||
|
||||
def get_encoded_length(prompt):
|
||||
length_after_extensions = apply_extensions('tokenized_length', prompt)
|
||||
if length_after_extensions is not None:
|
||||
@ -69,12 +127,36 @@ def get_encoded_length(prompt):
|
||||
return len(encode(prompt)[0])
|
||||
|
||||
|
||||
def decode(output_ids, skip_special_tokens=True):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||
def get_max_prompt_length(state):
|
||||
return state['truncation_length'] - state['max_new_tokens']
|
||||
|
||||
|
||||
def generate_reply_wrapper(question, state, stopping_strings=None):
|
||||
"""
|
||||
Returns formatted outputs for the UI
|
||||
"""
|
||||
reply = question if not shared.is_seq2seq else ''
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
for reply in generate_reply(question, state, stopping_strings, is_chat=False):
|
||||
if not shared.is_seq2seq:
|
||||
reply = question + reply
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']):
|
||||
reply = fix_gpt4chan(reply)
|
||||
return reply, generate_4chan_html(reply)
|
||||
else:
|
||||
return reply, generate_basic_html(reply)
|
||||
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
def fix_gpt4chan(s):
|
||||
"""
|
||||
Removes empty replies from gpt4chan outputs
|
||||
"""
|
||||
for i in range(10):
|
||||
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
||||
s = re.sub("--- [0-9]*\n *\n---", "---", s)
|
||||
@ -83,8 +165,10 @@ def fix_gpt4chan(s):
|
||||
return s
|
||||
|
||||
|
||||
# Fix the LaTeX equations in galactica
|
||||
def fix_galactica(s):
|
||||
"""
|
||||
Fix the LaTeX equations in GALACTICA
|
||||
"""
|
||||
s = s.replace(r'\[', r'$')
|
||||
s = s.replace(r'\]', r'$')
|
||||
s = s.replace(r'\(', r'$')
|
||||
@ -109,14 +193,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i
|
||||
return reply
|
||||
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']):
|
||||
reply = fix_gpt4chan(reply)
|
||||
return reply, generate_4chan_html(reply)
|
||||
else:
|
||||
return reply, generate_basic_html(reply)
|
||||
|
||||
|
||||
def set_manual_seed(seed):
|
||||
seed = int(seed)
|
||||
if seed == -1:
|
||||
@ -133,17 +209,6 @@ def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
|
||||
def generate_reply_wrapper(question, state, stopping_strings=None):
|
||||
reply = question if not shared.is_seq2seq else ''
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
for reply in generate_reply(question, state, stopping_strings, is_chat=False):
|
||||
if not shared.is_seq2seq:
|
||||
reply = question + reply
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
|
||||
def apply_stopping_strings(reply, all_stop_strings):
|
||||
stop_found = False
|
||||
for string in all_stop_strings:
|
||||
@ -169,61 +234,6 @@ def apply_stopping_strings(reply, all_stop_strings):
|
||||
return reply, stop_found
|
||||
|
||||
|
||||
def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
||||
generate_func = apply_extensions('custom_generate_reply')
|
||||
if generate_func is None:
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logger.error("No model is loaded! Select one in the Model tab.")
|
||||
yield ''
|
||||
return
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
|
||||
generate_func = generate_reply_custom
|
||||
else:
|
||||
generate_func = generate_reply_HF
|
||||
|
||||
# Preparing the input
|
||||
original_question = question
|
||||
if not is_chat:
|
||||
state = apply_extensions('state', state)
|
||||
question = apply_extensions('input', question, state)
|
||||
|
||||
# Finding the stopping strings
|
||||
all_stop_strings = []
|
||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||
if type(st) is list and len(st) > 0:
|
||||
all_stop_strings += st
|
||||
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{question}\n--------------------\n')
|
||||
|
||||
shared.stop_everything = False
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
last_update = -1
|
||||
reply = ''
|
||||
is_stream = state['stream']
|
||||
if len(all_stop_strings) > 0 and not state['stream']:
|
||||
state = copy.deepcopy(state)
|
||||
state['stream'] = True
|
||||
|
||||
for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
|
||||
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
|
||||
if is_stream:
|
||||
cur_time = time.time()
|
||||
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
|
||||
last_update = cur_time
|
||||
yield reply
|
||||
|
||||
if stop_found:
|
||||
break
|
||||
|
||||
if not is_chat:
|
||||
reply = apply_extensions('output', reply, state)
|
||||
|
||||
yield reply
|
||||
|
||||
|
||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
||||
@ -316,6 +326,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
|
||||
|
||||
def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
"""
|
||||
For models that do not use the transformers library for sampling
|
||||
"""
|
||||
seed = set_manual_seed(state['seed'])
|
||||
|
||||
t0 = time.time()
|
||||
|
@ -17,8 +17,6 @@ from pathlib import Path
|
||||
import gradio as gr
|
||||
import torch
|
||||
import transformers
|
||||
from modules.models import load_model, unload_model
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
@ -34,6 +32,7 @@ from modules.evaluate import (
|
||||
save_past_evaluations
|
||||
)
|
||||
from modules.logging_colors import logger
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.utils import natural_keys
|
||||
|
||||
# This mapping is from a very recent commit, not yet released.
|
||||
@ -65,100 +64,101 @@ WANT_INTERRUPT = False
|
||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
|
||||
|
||||
|
||||
def create_train_interface():
|
||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||
gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
|
||||
|
||||
with gr.Row():
|
||||
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
||||
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).')
|
||||
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
|
||||
|
||||
with gr.Row():
|
||||
copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=utils.get_available_loras())
|
||||
ui.create_refresh_button(copy_from, lambda: None, lambda: {'choices': utils.get_available_loras()}, 'refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
# TODO: Implement multi-device support.
|
||||
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
|
||||
batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
|
||||
|
||||
with gr.Row():
|
||||
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
|
||||
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
|
||||
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.')
|
||||
|
||||
# TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
|
||||
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, higher values like 128 or 256 are good for teaching content upgrades, extremely high values (1024+) are difficult to train but may improve fine-detail learning for large datasets. Higher ranks also require higher VRAM.')
|
||||
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
||||
|
||||
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
|
||||
|
||||
with gr.Tab(label='Formatted Dataset'):
|
||||
with gr.Row():
|
||||
dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
|
||||
eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
|
||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
|
||||
format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button')
|
||||
|
||||
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
||||
|
||||
with gr.Tab(label="Raw text file"):
|
||||
with gr.Row():
|
||||
raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
|
||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
||||
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
|
||||
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Hard Cut blocks that have less or equal characters than this number')
|
||||
def create_ui():
|
||||
with gr.Tab("Training", elem_id="training-tab"):
|
||||
tmp = gr.State('')
|
||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||
gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
|
||||
|
||||
with gr.Row():
|
||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||
|
||||
with gr.Accordion(label='Advanced Options', open=False):
|
||||
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
|
||||
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
|
||||
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
|
||||
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
|
||||
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
|
||||
add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item. In case of raw text, the EOS will be added at the Hard Cut")
|
||||
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
||||
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).')
|
||||
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
|
||||
|
||||
with gr.Row():
|
||||
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
||||
copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=utils.get_available_loras())
|
||||
ui.create_refresh_button(copy_from, lambda: None, lambda: {'choices': utils.get_available_loras()}, 'refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
|
||||
# TODO: Implement multi-device support.
|
||||
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
|
||||
batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
|
||||
|
||||
with gr.Row():
|
||||
start_button = gr.Button("Start LoRA Training")
|
||||
stop_button = gr.Button("Interrupt")
|
||||
with gr.Row():
|
||||
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
|
||||
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
|
||||
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.')
|
||||
|
||||
output = gr.Markdown(value="Ready")
|
||||
# TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
|
||||
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, higher values like 128 or 256 are good for teaching content upgrades, extremely high values (1024+) are difficult to train but may improve fine-detail learning for large datasets. Higher ranks also require higher VRAM.')
|
||||
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
||||
|
||||
with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True)
|
||||
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
|
||||
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
|
||||
|
||||
with gr.Tab(label='Formatted Dataset'):
|
||||
with gr.Row():
|
||||
stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
|
||||
max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
|
||||
dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
|
||||
eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
|
||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
|
||||
format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button')
|
||||
|
||||
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
||||
|
||||
with gr.Tab(label="Raw text file"):
|
||||
with gr.Row():
|
||||
raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
|
||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
||||
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
|
||||
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Hard Cut blocks that have less or equal characters than this number')
|
||||
|
||||
with gr.Row():
|
||||
start_current_evaluation = gr.Button("Evaluate loaded model")
|
||||
start_evaluation = gr.Button("Evaluate selected models")
|
||||
stop_evaluation = gr.Button("Interrupt")
|
||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||
|
||||
with gr.Column():
|
||||
evaluation_log = gr.Markdown(value='')
|
||||
with gr.Accordion(label='Advanced Options', open=False):
|
||||
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
|
||||
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
|
||||
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
|
||||
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
|
||||
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
|
||||
add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item. In case of raw text, the EOS will be added at the Hard Cut")
|
||||
|
||||
evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
|
||||
with gr.Row():
|
||||
save_comments = gr.Button('Save comments', elem_classes="small-button")
|
||||
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
|
||||
with gr.Row():
|
||||
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
||||
with gr.Row():
|
||||
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
|
||||
|
||||
with gr.Row():
|
||||
start_button = gr.Button("Start LoRA Training")
|
||||
stop_button = gr.Button("Interrupt")
|
||||
|
||||
output = gr.Markdown(value="Ready")
|
||||
|
||||
with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True)
|
||||
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
|
||||
with gr.Row():
|
||||
stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
|
||||
max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
|
||||
|
||||
with gr.Row():
|
||||
start_current_evaluation = gr.Button("Evaluate loaded model")
|
||||
start_evaluation = gr.Button("Evaluate selected models")
|
||||
stop_evaluation = gr.Button("Interrupt")
|
||||
|
||||
with gr.Column():
|
||||
evaluation_log = gr.Markdown(value='')
|
||||
|
||||
evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
|
||||
with gr.Row():
|
||||
save_comments = gr.Button('Save comments', elem_classes="small-button")
|
||||
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
|
||||
|
||||
# Training events
|
||||
|
||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
|
||||
|
||||
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||
@ -172,7 +172,6 @@ def create_train_interface():
|
||||
ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
|
||||
start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
|
||||
|
||||
tmp = gr.State('')
|
||||
start_current_evaluation.click(lambda: ['current model'], None, tmp)
|
||||
ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
|
||||
start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
|
||||
|
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
@ -11,9 +10,9 @@ with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
|
||||
css = f.read()
|
||||
with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
|
||||
chat_css = f.read()
|
||||
with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
|
||||
with open(Path(__file__).resolve().parent / '../js/main.js', 'r') as f:
|
||||
main_js = f.read()
|
||||
with open(Path(__file__).resolve().parent / '../css/save_files.js', 'r') as f:
|
||||
with open(Path(__file__).resolve().parent / '../js/save_files.js', 'r') as f:
|
||||
save_files_js = f.read()
|
||||
|
||||
refresh_symbol = '🔄'
|
||||
@ -30,6 +29,11 @@ theme = gr.themes.Default(
|
||||
background_fill_secondary='#eaeaea'
|
||||
)
|
||||
|
||||
if Path("notification.mp3").exists():
|
||||
audio_notification_js = "document.querySelector('#audio_notification audio')?.play();"
|
||||
else:
|
||||
audio_notification_js = ""
|
||||
|
||||
|
||||
def list_model_elements():
|
||||
elements = [
|
||||
|
262
modules/ui_chat.py
Normal file
262
modules/ui_chat.py
Normal file
@ -0,0 +1,262 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
|
||||
from modules import chat, shared, ui, utils
|
||||
from modules.html_generator import chat_html_wrapper
|
||||
from modules.text_generation import stop_everything_event
|
||||
from modules.utils import gradio
|
||||
|
||||
|
||||
def create_ui():
|
||||
|
||||
shared.gradio.update({
|
||||
'interface_state': gr.State({k: None for k in shared.input_elements}),
|
||||
'Chat input': gr.State(),
|
||||
'dummy': gr.State(),
|
||||
'history': gr.State({'internal': [], 'visible': []}),
|
||||
})
|
||||
|
||||
with gr.Tab('Text generation', elem_id='main'):
|
||||
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper({'internal': [], 'visible': []}, shared.settings['name1'], shared.settings['name2'], 'chat', 'cai-chat'))
|
||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||
with gr.Row():
|
||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
|
||||
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate', variant='primary')
|
||||
shared.gradio['Continue'] = gr.Button('Continue')
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||
shared.gradio['Remove last'] = gr.Button('Remove last', elem_classes=['button_nowrap'])
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
||||
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
||||
shared.gradio['Send dummy message'] = gr.Button('Send dummy message')
|
||||
shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply')
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['Clear history'] = gr.Button('Clear history')
|
||||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant='stop', visible=False)
|
||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['start_with'] = gr.Textbox(label='Start reply with', placeholder='Sure thing!', value=shared.settings['start_with'])
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['mode'] = gr.Radio(choices=['chat', 'chat-instruct', 'instruct'], value=shared.settings['mode'] if shared.settings['mode'] in ['chat', 'instruct', 'chat-instruct'] else 'chat', label='Mode', info='Defines how the chat prompt is generated. In instruct and chat-instruct modes, the instruction template selected under "Chat settings" must match the current model.')
|
||||
shared.gradio['chat_style'] = gr.Dropdown(choices=utils.get_available_chat_styles(), label='Chat style', value=shared.settings['chat_style'], visible=shared.settings['mode'] != 'instruct')
|
||||
|
||||
with gr.Tab('Chat settings', elem_id='chat-settings'):
|
||||
with gr.Tab("Character"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=8):
|
||||
with gr.Row():
|
||||
shared.gradio['character_menu'] = gr.Dropdown(value='None', choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.', elem_classes='slim-dropdown')
|
||||
ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button')
|
||||
shared.gradio['save_character'] = gr.Button('💾', elem_classes='refresh-button')
|
||||
shared.gradio['delete_character'] = gr.Button('🗑️', elem_classes='refresh-button')
|
||||
|
||||
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
|
||||
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
|
||||
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context', elem_classes=['add_scrollbar'])
|
||||
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting', elem_classes=['add_scrollbar'])
|
||||
|
||||
with gr.Column(scale=1):
|
||||
shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil')
|
||||
shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None)
|
||||
|
||||
with gr.Tab("Instruction template"):
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Instruction template', value='None', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.', elem_classes='slim-dropdown')
|
||||
ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button')
|
||||
shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button')
|
||||
shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button')
|
||||
|
||||
shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string')
|
||||
shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string')
|
||||
shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context')
|
||||
shared.gradio['turn_template'] = gr.Textbox(value=shared.settings['turn_template'], lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.')
|
||||
with gr.Row():
|
||||
shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.', elem_classes=['add_scrollbar'])
|
||||
|
||||
with gr.Tab('Chat history'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['save_chat_history'] = gr.Button(value='Save history')
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['load_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'], label="Upload History JSON")
|
||||
|
||||
with gr.Tab('Upload character'):
|
||||
with gr.Tab('YAML or JSON'):
|
||||
with gr.Row():
|
||||
shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json', '.yaml'], label='JSON or YAML File')
|
||||
shared.gradio['upload_img_bot'] = gr.Image(type='pil', label='Profile Picture (optional)')
|
||||
|
||||
shared.gradio['Submit character'] = gr.Button(value='Submit', interactive=False)
|
||||
|
||||
with gr.Tab('TavernAI PNG'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['upload_img_tavern'] = gr.Image(type='pil', label='TavernAI PNG File', elem_id="upload_img_tavern")
|
||||
shared.gradio['tavern_json'] = gr.State()
|
||||
with gr.Column():
|
||||
shared.gradio['tavern_name'] = gr.Textbox(value='', lines=1, label='Name', interactive=False)
|
||||
shared.gradio['tavern_desc'] = gr.Textbox(value='', lines=4, max_lines=4, label='Description', interactive=False)
|
||||
|
||||
shared.gradio['Submit tavern character'] = gr.Button(value='Submit', interactive=False)
|
||||
|
||||
|
||||
def create_event_handlers():
|
||||
gen_events = []
|
||||
|
||||
shared.input_params = gradio('Chat input', 'start_with', 'interface_state')
|
||||
clear_arr = gradio('Clear history-confirm', 'Clear history', 'Clear history-cancel')
|
||||
shared.reload_inputs = gradio('history', 'name1', 'name2', 'mode', 'chat_style')
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then(
|
||||
chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then(
|
||||
chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Regenerate'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
partial(chat.generate_chat_reply_wrapper, regenerate=True), shared.input_params, gradio('display', 'history'), show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Continue'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
partial(chat.generate_chat_reply_wrapper, _continue=True), shared.input_params, gradio('display', 'history'), show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Impersonate'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda x: x, gradio('textbox'), gradio('Chat input'), show_progress=False).then(
|
||||
chat.impersonate_wrapper, shared.input_params, gradio('textbox'), show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
)
|
||||
|
||||
shared.gradio['Replace last reply'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
chat.replace_last_reply, gradio('textbox', 'interface_state'), gradio('history')).then(
|
||||
lambda: '', None, gradio('textbox'), show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
|
||||
|
||||
shared.gradio['Send dummy message'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
chat.send_dummy_message, gradio('textbox', 'interface_state'), gradio('history')).then(
|
||||
lambda: '', None, gradio('textbox'), show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
|
||||
|
||||
shared.gradio['Send dummy reply'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
chat.send_dummy_reply, gradio('textbox', 'interface_state'), gradio('history')).then(
|
||||
lambda: '', None, gradio('textbox'), show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
|
||||
|
||||
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
||||
chat.clear_chat_log, gradio('interface_state'), gradio('history')).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
|
||||
|
||||
shared.gradio['Remove last'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
chat.remove_last_message, gradio('history'), gradio('textbox', 'history'), show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
|
||||
|
||||
shared.gradio['character_menu'].change(
|
||||
partial(chat.load_character, instruct=False), gradio('character_menu', 'name1', 'name2'), gradio('name1', 'name2', 'character_picture', 'greeting', 'context', 'dummy')).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
chat.load_persistent_history, gradio('interface_state'), gradio('history')).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||
|
||||
shared.gradio['Stop'].click(
|
||||
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||
|
||||
shared.gradio['mode'].change(
|
||||
lambda x: gr.update(visible=x != 'instruct'), gradio('mode'), gradio('chat_style'), show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||
|
||||
shared.gradio['chat_style'].change(chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||
shared.gradio['instruction_template'].change(
|
||||
partial(chat.load_character, instruct=True), gradio('instruction_template', 'name1_instruct', 'name2_instruct'), gradio('name1_instruct', 'name2_instruct', 'dummy', 'dummy', 'context_instruct', 'turn_template'))
|
||||
|
||||
shared.gradio['load_chat_history'].upload(
|
||||
chat.load_history, gradio('load_chat_history', 'history'), gradio('history')).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||
None, None, None, _js='() => {alert("The history has been loaded.")}')
|
||||
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, gradio('history'), gradio('textbox'), show_progress=False)
|
||||
|
||||
# Save/delete a character
|
||||
shared.gradio['save_character'].click(
|
||||
lambda x: x, gradio('name2'), gradio('save_character_filename')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('character_saver'))
|
||||
|
||||
shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter'))
|
||||
|
||||
shared.gradio['save_template'].click(
|
||||
lambda: 'My Template.yaml', None, gradio('save_filename')).then(
|
||||
lambda: 'characters/instruction-following/', None, gradio('save_root')).then(
|
||||
chat.generate_instruction_template_yaml, gradio('name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template'), gradio('save_contents')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('file_saver'))
|
||||
|
||||
shared.gradio['delete_template'].click(
|
||||
lambda x: f'{x}.yaml', gradio('instruction_template'), gradio('delete_filename')).then(
|
||||
lambda: 'characters/instruction-following/', None, gradio('delete_root')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||
|
||||
shared.gradio['save_chat_history'].click(
|
||||
lambda x: json.dumps(x, indent=4), gradio('history'), gradio('temporary_text')).then(
|
||||
None, gradio('temporary_text', 'character_menu', 'mode'), None, _js=f"(hist, char, mode) => {{{ui.save_files_js}; saveHistory(hist, char, mode)}}")
|
||||
|
||||
shared.gradio['Submit character'].click(
|
||||
chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu')).then(
|
||||
None, None, None, _js='() => {alert("The character has been loaded.")}')
|
||||
|
||||
shared.gradio['Submit tavern character'].click(
|
||||
chat.upload_tavern_character, gradio('upload_img_tavern', 'tavern_json'), gradio('character_menu')).then(
|
||||
None, None, None, _js='() => {alert("The character has been loaded.")}')
|
||||
|
||||
shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character'))
|
||||
shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character'))
|
||||
shared.gradio['upload_img_tavern'].upload(chat.check_tavern_character, gradio('upload_img_tavern'), gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False)
|
||||
shared.gradio['upload_img_tavern'].clear(lambda: (None, None, None, gr.update(interactive=False)), None, gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False)
|
||||
shared.gradio['your_picture'].change(
|
||||
chat.upload_your_profile_picture, gradio('your_picture'), None).then(
|
||||
partial(chat.redraw_html, reset_cache=True), shared.reload_inputs, gradio('display'))
|
94
modules/ui_default.py
Normal file
94
modules/ui_default.py
Normal file
@ -0,0 +1,94 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared, ui, utils
|
||||
from modules.prompts import count_tokens, load_prompt
|
||||
from modules.text_generation import (
|
||||
generate_reply_wrapper,
|
||||
stop_everything_event
|
||||
)
|
||||
from modules.utils import gradio
|
||||
|
||||
|
||||
def create_ui():
|
||||
default_text = load_prompt(shared.settings['prompt'])
|
||||
|
||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||
shared.gradio['last_input'] = gr.State('')
|
||||
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes=['textbox_default', 'add_scrollbar'], lines=27, label='Input')
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
with gr.Row():
|
||||
shared.gradio['Generate'] = gr.Button('Generate', variant='primary')
|
||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
|
||||
shared.gradio['Continue'] = gr.Button('Continue')
|
||||
shared.gradio['count_tokens'] = gr.Button('Count tokens')
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['prompt_menu'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt', elem_classes='slim-dropdown')
|
||||
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 'refresh-button')
|
||||
shared.gradio['save_prompt'] = gr.Button('💾', elem_classes='refresh-button')
|
||||
shared.gradio['delete_prompt'] = gr.Button('🗑️', elem_classes='refresh-button')
|
||||
|
||||
shared.gradio['status'] = gr.Markdown('')
|
||||
|
||||
with gr.Column():
|
||||
with gr.Tab('Raw'):
|
||||
shared.gradio['output_textbox'] = gr.Textbox(lines=27, label='Output', elem_classes=['textbox_default_output', 'add_scrollbar'])
|
||||
|
||||
with gr.Tab('Markdown'):
|
||||
shared.gradio['markdown_render'] = gr.Button('Render')
|
||||
shared.gradio['markdown'] = gr.Markdown()
|
||||
|
||||
with gr.Tab('HTML'):
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
|
||||
|
||||
def create_event_handlers():
|
||||
gen_events = []
|
||||
shared.input_params = gradio('textbox', 'interface_state')
|
||||
output_params = gradio('output_textbox', 'html')
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
lambda x: x, gradio('textbox'), gradio('last_input')).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
lambda x: x, gradio('textbox'), gradio('last_input')).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
shared.gradio['markdown_render'].click(lambda x: x, gradio('output_textbox'), gradio('markdown'), queue=False)
|
||||
gen_events.append(shared.gradio['Continue'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
generate_reply_wrapper, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['prompt_menu'].change(load_prompt, gradio('prompt_menu'), gradio('textbox'), show_progress=False)
|
||||
shared.gradio['save_prompt'].click(
|
||||
lambda x: x, gradio('textbox'), gradio('save_contents')).then(
|
||||
lambda: 'prompts/', None, gradio('save_root')).then(
|
||||
lambda: utils.current_time() + '.txt', None, gradio('save_filename')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('file_saver'))
|
||||
|
||||
shared.gradio['delete_prompt'].click(
|
||||
lambda: 'prompts/', None, gradio('delete_root')).then(
|
||||
lambda x: x + '.txt', gradio('prompt_menu'), gradio('delete_filename')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||
|
||||
shared.gradio['count_tokens'].click(count_tokens, gradio('textbox'), gradio('status'), show_progress=False)
|
108
modules/ui_file_saving.py
Normal file
108
modules/ui_file_saving.py
Normal file
@ -0,0 +1,108 @@
|
||||
import json
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import chat, presets, shared, ui, utils
|
||||
from modules.utils import gradio
|
||||
|
||||
|
||||
def create_ui():
|
||||
|
||||
# Text file saver
|
||||
with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']:
|
||||
shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name')
|
||||
shared.gradio['save_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.', interactive=False)
|
||||
shared.gradio['save_contents'] = gr.Textbox(lines=10, label='File contents')
|
||||
with gr.Row():
|
||||
shared.gradio['save_confirm'] = gr.Button('Save', elem_classes="small-button")
|
||||
shared.gradio['save_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||
|
||||
# Text file deleter
|
||||
with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['file_deleter']:
|
||||
shared.gradio['delete_filename'] = gr.Textbox(lines=1, label='File name')
|
||||
shared.gradio['delete_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.', interactive=False)
|
||||
with gr.Row():
|
||||
shared.gradio['delete_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop')
|
||||
shared.gradio['delete_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||
|
||||
# Character saver/deleter
|
||||
if shared.is_chat():
|
||||
with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']:
|
||||
shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info='The character will be saved to your characters/ folder with this base filename.')
|
||||
with gr.Row():
|
||||
shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button")
|
||||
shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||
|
||||
with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['character_deleter']:
|
||||
gr.Markdown('Confirm the character deletion?')
|
||||
with gr.Row():
|
||||
shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop')
|
||||
shared.gradio['delete_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||
|
||||
|
||||
def create_event_handlers():
|
||||
shared.gradio['save_confirm'].click(
|
||||
lambda x, y, z: utils.save_file(x + y, z), gradio('save_root', 'save_filename', 'save_contents'), None).then(
|
||||
lambda: gr.update(visible=False), None, gradio('file_saver'))
|
||||
|
||||
shared.gradio['delete_confirm'].click(
|
||||
lambda x, y: utils.delete_file(x + y), gradio('delete_root', 'delete_filename'), None).then(
|
||||
lambda: gr.update(visible=False), None, gradio('file_deleter'))
|
||||
|
||||
shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_deleter'))
|
||||
shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_saver'))
|
||||
if shared.is_chat():
|
||||
shared.gradio['save_character_confirm'].click(
|
||||
chat.save_character, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), None).then(
|
||||
lambda: gr.update(visible=False), None, gradio('character_saver'))
|
||||
|
||||
shared.gradio['delete_character_confirm'].click(
|
||||
chat.delete_character, gradio('character_menu'), None).then(
|
||||
lambda: gr.update(visible=False), None, gradio('character_deleter')).then(
|
||||
lambda: gr.update(choices=utils.get_available_characters()), None, gradio('character_menu'))
|
||||
|
||||
shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver'))
|
||||
shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter'))
|
||||
|
||||
shared.gradio['save_preset'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
presets.generate_preset_yaml, gradio('interface_state'), gradio('save_contents')).then(
|
||||
lambda: 'presets/', None, gradio('save_root')).then(
|
||||
lambda: 'My Preset.yaml', None, gradio('save_filename')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('file_saver'))
|
||||
|
||||
shared.gradio['delete_preset'].click(
|
||||
lambda x: f'{x}.yaml', gradio('preset_menu'), gradio('delete_filename')).then(
|
||||
lambda: 'presets/', None, gradio('delete_root')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||
|
||||
if not shared.args.multi_user:
|
||||
shared.gradio['save_session'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda x: json.dumps(x, indent=4), gradio('interface_state'), gradio('temporary_text')).then(
|
||||
None, gradio('temporary_text'), None, _js=f"(contents) => {{{ui.save_files_js}; saveSession(contents, \"{shared.get_mode()}\")}}")
|
||||
|
||||
if shared.is_chat():
|
||||
shared.gradio['load_session'].upload(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
load_session, gradio('load_session', 'interface_state'), gradio('interface_state')).then(
|
||||
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||
None, None, None, _js='() => {alert("The session has been loaded.")}')
|
||||
else:
|
||||
shared.gradio['load_session'].upload(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
load_session, gradio('load_session', 'interface_state'), gradio('interface_state')).then(
|
||||
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
|
||||
None, None, None, _js='() => {alert("The session has been loaded.")}')
|
||||
|
||||
|
||||
def load_session(file, state):
|
||||
decoded_file = file if type(file) == str else file.decode('utf-8')
|
||||
data = json.loads(decoded_file)
|
||||
|
||||
if shared.is_chat() and 'character_menu' in data and state.get('character_menu') != data.get('character_menu'):
|
||||
shared.session_is_loading = True
|
||||
|
||||
state.update(data)
|
||||
return state
|
229
modules/ui_model_menu.py
Normal file
229
modules/ui_model_menu.py
Normal file
@ -0,0 +1,229 @@
|
||||
import importlib
|
||||
import math
|
||||
import re
|
||||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import gradio as gr
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from modules import loaders, shared, ui, utils
|
||||
from modules.logging_colors import logger
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import (
|
||||
apply_model_settings_to_state,
|
||||
save_model_settings,
|
||||
update_model_parameters
|
||||
)
|
||||
from modules.utils import gradio
|
||||
|
||||
|
||||
def create_ui():
|
||||
# Finding the default values for the GPU and CPU memories
|
||||
total_mem = []
|
||||
for i in range(torch.cuda.device_count()):
|
||||
total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
|
||||
|
||||
default_gpu_mem = []
|
||||
if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
|
||||
for i in shared.args.gpu_memory:
|
||||
if 'mib' in i.lower():
|
||||
default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)))
|
||||
else:
|
||||
default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000)
|
||||
while len(default_gpu_mem) < len(total_mem):
|
||||
default_gpu_mem.append(0)
|
||||
|
||||
total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024))
|
||||
if shared.args.cpu_memory is not None:
|
||||
default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory)
|
||||
else:
|
||||
default_cpu_mem = 0
|
||||
|
||||
with gr.Tab("Model", elem_id="model-tab"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['model_menu'] = gr.Dropdown(choices=utils.get_available_models(), value=shared.model_name, label='Model', elem_classes='slim-dropdown')
|
||||
ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button')
|
||||
shared.gradio['load_model'] = gr.Button("Load", visible=not shared.settings['autoload_model'], elem_classes='refresh-button')
|
||||
shared.gradio['unload_model'] = gr.Button("Unload", elem_classes='refresh-button')
|
||||
shared.gradio['reload_model'] = gr.Button("Reload", elem_classes='refresh-button')
|
||||
shared.gradio['save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button')
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=utils.get_available_loras(), value=shared.lora_names, label='LoRA(s)', elem_classes='slim-dropdown')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
|
||||
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "ExLlama_HF", "ExLlama", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp", "llamacpp_HF"], value=None)
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
for i in range(len(total_mem)):
|
||||
shared.gradio[f'gpu_memory_{i}'] = gr.Slider(label=f"gpu-memory in MiB for device :{i}", maximum=total_mem[i], value=default_gpu_mem[i])
|
||||
|
||||
shared.gradio['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem, value=default_cpu_mem)
|
||||
shared.gradio['transformers_info'] = gr.Markdown('load-in-4bit params:')
|
||||
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype)
|
||||
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type)
|
||||
|
||||
shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128, value=shared.args.n_gpu_layers)
|
||||
shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx", value=shared.args.n_ctx)
|
||||
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32, value=shared.args.threads)
|
||||
shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch)
|
||||
shared.gradio['n_gqa'] = gr.Slider(minimum=0, maximum=16, step=1, label="n_gqa", value=shared.args.n_gqa, info='grouped-query attention. Must be 8 for llama-2 70b.')
|
||||
shared.gradio['rms_norm_eps'] = gr.Slider(minimum=0, maximum=1e-5, step=1e-6, label="rms_norm_eps", value=shared.args.n_gqa, info='5e-6 is a good value for llama-2 models.')
|
||||
|
||||
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None")
|
||||
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")
|
||||
shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
|
||||
shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
|
||||
shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.')
|
||||
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
|
||||
shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=0, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len)
|
||||
shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.', value=shared.args.compress_pos_emb)
|
||||
shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=32, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Scaling is not identical to embedding compression. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value)
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
|
||||
shared.gradio['no_inject_fused_attention'] = gr.Checkbox(label="no_inject_fused_attention", value=shared.args.no_inject_fused_attention, info='Disable fused attention. Fused attention improves inference performance but uses more VRAM. Disable if running low on VRAM.')
|
||||
shared.gradio['no_inject_fused_mlp'] = gr.Checkbox(label="no_inject_fused_mlp", value=shared.args.no_inject_fused_mlp, info='Affects Triton only. Disable fused MLP. Fused MLP improves performance but uses more VRAM. Disable if running low on VRAM.')
|
||||
shared.gradio['no_use_cuda_fp16'] = gr.Checkbox(label="no_use_cuda_fp16", value=shared.args.no_use_cuda_fp16, info='This can make models faster on some systems.')
|
||||
shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.')
|
||||
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu)
|
||||
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
|
||||
shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
|
||||
shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
|
||||
shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
|
||||
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
|
||||
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant)
|
||||
shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
|
||||
shared.gradio['low_vram'] = gr.Checkbox(label="low-vram", value=shared.args.low_vram)
|
||||
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
|
||||
shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed)
|
||||
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
|
||||
shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
|
||||
shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
|
||||
shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')
|
||||
shared.gradio['llamacpp_HF_info'] = gr.Markdown('llamacpp_HF is a wrapper that lets you use llama.cpp like a Transformers model, which means it can use the Transformers samplers. To use it, make sure to first download oobabooga/llama-tokenizer under "Download custom model or LoRA".')
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.')
|
||||
|
||||
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
|
||||
shared.gradio['download_model_button'] = gr.Button("Download")
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
|
||||
|
||||
|
||||
def create_event_handlers():
|
||||
shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()))
|
||||
|
||||
# In this event handler, the interface state is read and updated
|
||||
# with the model defaults (if any), and then the model is loaded
|
||||
# unless "autoload_model" is unchecked
|
||||
shared.gradio['model_menu'].change(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
apply_model_settings_to_state, gradio('model_menu', 'interface_state'), gradio('interface_state')).then(
|
||||
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
|
||||
update_model_parameters, gradio('interface_state'), None).then(
|
||||
load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False)
|
||||
|
||||
shared.gradio['load_model'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
update_model_parameters, gradio('interface_state'), None).then(
|
||||
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False)
|
||||
|
||||
shared.gradio['unload_model'].click(
|
||||
unload_model, None, None).then(
|
||||
lambda: "Model unloaded", None, gradio('model_status'))
|
||||
|
||||
shared.gradio['reload_model'].click(
|
||||
unload_model, None, None).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
update_model_parameters, gradio('interface_state'), None).then(
|
||||
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False)
|
||||
|
||||
shared.gradio['save_model_settings'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False)
|
||||
|
||||
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False)
|
||||
shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu'), gradio('model_status'), show_progress=True)
|
||||
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), gradio('load_model'))
|
||||
|
||||
|
||||
def load_model_wrapper(selected_model, loader, autoload=False):
|
||||
if not autoload:
|
||||
yield f"The settings for {selected_model} have been updated.\nClick on \"Load\" to load it."
|
||||
return
|
||||
|
||||
if selected_model == 'None':
|
||||
yield "No model selected"
|
||||
else:
|
||||
try:
|
||||
yield f"Loading {selected_model}..."
|
||||
shared.model_name = selected_model
|
||||
unload_model()
|
||||
if selected_model != '':
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name, loader)
|
||||
|
||||
if shared.model is not None:
|
||||
yield f"Successfully loaded {selected_model}"
|
||||
else:
|
||||
yield f"Failed to load {selected_model}."
|
||||
except:
|
||||
exc = traceback.format_exc()
|
||||
logger.error('Failed to load the model.')
|
||||
print(exc)
|
||||
yield exc.replace('\n', '\n\n')
|
||||
|
||||
|
||||
def load_lora_wrapper(selected_loras):
|
||||
yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras)))
|
||||
add_lora_to_model(selected_loras)
|
||||
yield ("Successfuly applied the LoRAs")
|
||||
|
||||
|
||||
def download_model_wrapper(repo_id, progress=gr.Progress()):
|
||||
try:
|
||||
downloader_module = importlib.import_module("download-model")
|
||||
downloader = downloader_module.ModelDownloader()
|
||||
repo_id_parts = repo_id.split(":")
|
||||
model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id
|
||||
branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"
|
||||
check = False
|
||||
|
||||
progress(0.0)
|
||||
yield ("Cleaning up the model/branch names")
|
||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||
|
||||
yield ("Getting the download links from Hugging Face")
|
||||
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||
|
||||
yield ("Getting the output folder")
|
||||
base_folder = shared.args.lora_dir if is_lora else shared.args.model_dir
|
||||
output_folder = downloader.get_output_folder(model, branch, is_lora, base_folder=base_folder)
|
||||
|
||||
if check:
|
||||
progress(0.5)
|
||||
yield ("Checking previously downloaded files")
|
||||
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||
progress(1.0)
|
||||
else:
|
||||
yield (f"Downloading files to {output_folder}")
|
||||
downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=1)
|
||||
yield ("Done!")
|
||||
except:
|
||||
progress(1.0)
|
||||
yield traceback.format_exc().replace('\n', '\n\n')
|
98
modules/ui_notebook.py
Normal file
98
modules/ui_notebook.py
Normal file
@ -0,0 +1,98 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared, ui, utils
|
||||
from modules.prompts import count_tokens, load_prompt
|
||||
from modules.text_generation import (
|
||||
generate_reply_wrapper,
|
||||
stop_everything_event
|
||||
)
|
||||
from modules.utils import gradio
|
||||
|
||||
|
||||
def create_ui():
|
||||
default_text = load_prompt(shared.settings['prompt'])
|
||||
|
||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||
shared.gradio['last_input'] = gr.State('')
|
||||
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
with gr.Tab('Raw'):
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes=['textbox', 'add_scrollbar'], lines=27)
|
||||
|
||||
with gr.Tab('Markdown'):
|
||||
shared.gradio['markdown_render'] = gr.Button('Render')
|
||||
shared.gradio['markdown'] = gr.Markdown()
|
||||
|
||||
with gr.Tab('HTML'):
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button")
|
||||
shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button", elem_id='stop')
|
||||
shared.gradio['Undo'] = gr.Button('Undo', elem_classes="small-button")
|
||||
shared.gradio['Regenerate'] = gr.Button('Regenerate', elem_classes="small-button")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
gr.HTML('<div style="padding-bottom: 13px"></div>')
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
with gr.Row():
|
||||
shared.gradio['prompt_menu'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt', elem_classes='slim-dropdown')
|
||||
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, ['refresh-button', 'refresh-button-small'])
|
||||
shared.gradio['save_prompt'] = gr.Button('💾', elem_classes=['refresh-button', 'refresh-button-small'])
|
||||
shared.gradio['delete_prompt'] = gr.Button('🗑️', elem_classes=['refresh-button', 'refresh-button-small'])
|
||||
|
||||
shared.gradio['count_tokens'] = gr.Button('Count tokens')
|
||||
shared.gradio['status'] = gr.Markdown('')
|
||||
|
||||
|
||||
def create_event_handlers():
|
||||
gen_events = []
|
||||
|
||||
shared.input_params = gradio('textbox', 'interface_state')
|
||||
output_params = gradio('textbox', 'html')
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
lambda x: x, gradio('textbox'), gradio('last_input')).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
lambda x: x, gradio('textbox'), gradio('last_input')).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
shared.gradio['Undo'].click(lambda x: x, gradio('last_input'), gradio('textbox'), show_progress=False)
|
||||
shared.gradio['markdown_render'].click(lambda x: x, gradio('textbox'), gradio('markdown'), queue=False)
|
||||
gen_events.append(shared.gradio['Regenerate'].click(
|
||||
lambda x: x, gradio('last_input'), gradio('textbox'), show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||
)
|
||||
|
||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['prompt_menu'].change(load_prompt, gradio('prompt_menu'), gradio('textbox'), show_progress=False)
|
||||
shared.gradio['save_prompt'].click(
|
||||
lambda x: x, gradio('textbox'), gradio('save_contents')).then(
|
||||
lambda: 'prompts/', None, gradio('save_root')).then(
|
||||
lambda: utils.current_time() + '.txt', None, gradio('save_filename')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('file_saver'))
|
||||
|
||||
shared.gradio['delete_prompt'].click(
|
||||
lambda: 'prompts/', None, gradio('delete_root')).then(
|
||||
lambda x: x + '.txt', gradio('prompt_menu'), gradio('delete_filename')).then(
|
||||
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||
|
||||
shared.gradio['count_tokens'].click(count_tokens, gradio('textbox'), gradio('status'), show_progress=False)
|
143
modules/ui_parameters.py
Normal file
143
modules/ui_parameters.py
Normal file
@ -0,0 +1,143 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import loaders, presets, shared, ui, utils
|
||||
from modules.utils import gradio
|
||||
|
||||
|
||||
def create_ui(default_preset):
|
||||
generate_params = presets.load_preset(default_preset)
|
||||
with gr.Tab("Parameters", elem_id="parameters"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['preset_menu'] = gr.Dropdown(choices=utils.get_available_presets(), value=default_preset, label='Generation parameters preset', elem_classes='slim-dropdown')
|
||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': utils.get_available_presets()}, 'refresh-button')
|
||||
shared.gradio['save_preset'] = gr.Button('💾', elem_classes='refresh-button')
|
||||
shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button')
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['filter_by_loader'] = gr.Dropdown(label="Filter by loader", choices=["All", "Transformers", "ExLlama_HF", "ExLlama", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp", "llamacpp_HF"], value="All", elem_classes='slim-dropdown')
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
|
||||
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
|
||||
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
|
||||
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
|
||||
shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff')
|
||||
shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff')
|
||||
shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs')
|
||||
shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=generate_params['top_a'], step=0.01, label='top_a')
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
|
||||
shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
|
||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
|
||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
|
||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length')
|
||||
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
|
||||
with gr.Accordion("Learn more", open=False):
|
||||
gr.Markdown("""
|
||||
|
||||
For a technical description of the parameters, the [transformers documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) is a good reference.
|
||||
|
||||
The best presets, according to the [Preset Arena](https://github.com/oobabooga/oobabooga.github.io/blob/main/arena/results.md) experiment, are:
|
||||
|
||||
* Instruction following:
|
||||
1) Divine Intellect
|
||||
2) Big O
|
||||
3) simple-1
|
||||
4) Space Alien
|
||||
5) StarChat
|
||||
6) Titanic
|
||||
7) tfs-with-top-a
|
||||
8) Asterism
|
||||
9) Contrastive Search
|
||||
|
||||
* Chat:
|
||||
1) Midnight Enigma
|
||||
2) Yara
|
||||
3) Shortwave
|
||||
|
||||
### Temperature
|
||||
Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.
|
||||
### top_p
|
||||
If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.
|
||||
### top_k
|
||||
Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.
|
||||
### typical_p
|
||||
If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
|
||||
### epsilon_cutoff
|
||||
In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled. Should be used with top_p, top_k, and eta_cutoff set to 0.
|
||||
### eta_cutoff
|
||||
In units of 1e-4; a reasonable value is 3. Should be used with top_p, top_k, and epsilon_cutoff set to 0.
|
||||
### repetition_penalty
|
||||
Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.
|
||||
### repetition_penalty_range
|
||||
The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
|
||||
### encoder_repetition_penalty
|
||||
Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.
|
||||
### no_repeat_ngram_size
|
||||
If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.
|
||||
### min_length
|
||||
Minimum generation length in tokens.
|
||||
### penalty_alpha
|
||||
Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4.
|
||||
|
||||
""", elem_classes="markdown")
|
||||
|
||||
with gr.Column():
|
||||
create_chat_settings_menus()
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=generate_params['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.')
|
||||
shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt')
|
||||
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
|
||||
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
||||
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.')
|
||||
|
||||
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.')
|
||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
||||
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
|
||||
with gr.Column():
|
||||
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
|
||||
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
|
||||
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||
|
||||
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
|
||||
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
|
||||
|
||||
|
||||
def create_event_handlers():
|
||||
shared.gradio['filter_by_loader'].change(loaders.blacklist_samplers, gradio('filter_by_loader'), gradio(loaders.list_all_samplers()), show_progress=False)
|
||||
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()))
|
||||
|
||||
|
||||
def create_chat_settings_menus():
|
||||
if not shared.is_chat():
|
||||
return
|
||||
|
||||
with gr.Box():
|
||||
gr.Markdown("Chat parameters")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)', info='New generations will be called until either this number is reached or no new content is generated between two iterations.')
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
|
71
modules/ui_session.py
Normal file
71
modules/ui_session.py
Normal file
@ -0,0 +1,71 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared, ui, utils
|
||||
from modules.github import clone_or_pull_repository
|
||||
from modules.utils import gradio
|
||||
|
||||
|
||||
def create_ui():
|
||||
with gr.Tab("Session", elem_id="session-tab"):
|
||||
modes = ["default", "notebook", "chat"]
|
||||
current_mode = "default"
|
||||
for mode in modes[1:]:
|
||||
if getattr(shared.args, mode):
|
||||
current_mode = mode
|
||||
break
|
||||
|
||||
cmd_list = vars(shared.args)
|
||||
bool_list = sorted([k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes + ui.list_model_elements()])
|
||||
bool_active = [k for k in bool_list if vars(shared.args)[k]]
|
||||
|
||||
with gr.Row():
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode", elem_classes='slim-dropdown')
|
||||
shared.gradio['reset_interface'] = gr.Button("Apply and restart", elem_classes="small-button", variant="primary")
|
||||
shared.gradio['toggle_dark_mode'] = gr.Button('Toggle 💡', elem_classes="small-button")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=utils.get_available_extensions(), value=shared.args.extensions, label="Available extensions", info='Note that some of these extensions may require manually installing Python requirements through the command: pip install -r extensions/extension_name/requirements.txt', elem_classes='checkboxgroup-table')
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags", elem_classes='checkboxgroup-table')
|
||||
|
||||
with gr.Column():
|
||||
if not shared.args.multi_user:
|
||||
shared.gradio['save_session'] = gr.Button('Save session', elem_id="save_session")
|
||||
shared.gradio['load_session'] = gr.File(type='binary', file_types=['.json'], label="Upload Session JSON")
|
||||
|
||||
extension_name = gr.Textbox(lines=1, label='Install or update an extension', info='Enter the GitHub URL below and press Enter. For a list of extensions, see: https://github.com/oobabooga/text-generation-webui-extensions ⚠️ WARNING ⚠️ : extensions can execute arbitrary code. Make sure to inspect their source code before activating them.')
|
||||
extension_status = gr.Markdown()
|
||||
|
||||
extension_name.submit(
|
||||
clone_or_pull_repository, extension_name, extension_status, show_progress=False).then(
|
||||
lambda: gr.update(choices=utils.get_available_extensions(), value=shared.args.extensions), None, gradio('extensions_menu'))
|
||||
|
||||
# Reset interface event
|
||||
shared.gradio['reset_interface'].click(
|
||||
set_interface_arguments, gradio('interface_modes_menu', 'extensions_menu', 'bool_menu'), None).then(
|
||||
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;padding-top:20%;margin:0;height:100vh;color:lightgray;text-align:center;background:var(--body-background-fill)">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||
|
||||
shared.gradio['toggle_dark_mode'].click(lambda: None, None, None, _js='() => {document.getElementsByTagName("body")[0].classList.toggle("dark")}')
|
||||
|
||||
|
||||
def set_interface_arguments(interface_mode, extensions, bool_active):
|
||||
modes = ["default", "notebook", "chat", "cai_chat"]
|
||||
cmd_list = vars(shared.args)
|
||||
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
||||
|
||||
shared.args.extensions = extensions
|
||||
for k in modes[1:]:
|
||||
setattr(shared.args, k, False)
|
||||
if interface_mode != "default":
|
||||
setattr(shared.args, interface_mode, True)
|
||||
for k in bool_list:
|
||||
setattr(shared.args, k, False)
|
||||
for k in bool_active:
|
||||
setattr(shared.args, k, True)
|
||||
|
||||
shared.need_restart = True
|
Loading…
x
Reference in New Issue
Block a user