mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Refactor everything (#3481)
This commit is contained in:
parent
d4b851bdc8
commit
65aa11890f
@ -3,7 +3,6 @@ import copy
|
|||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import gc
|
import gc
|
||||||
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import hashlib
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@ -14,7 +14,7 @@ from transformers import (
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
@ -26,9 +26,9 @@ def infer_loader(model_name):
|
|||||||
loader = 'AutoGPTQ'
|
loader = 'AutoGPTQ'
|
||||||
elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
|
elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
|
||||||
loader = 'llama.cpp'
|
loader = 'llama.cpp'
|
||||||
elif re.match('.*ggml.*\.bin', model_name.lower()):
|
elif re.match(r'.*ggml.*\.bin', model_name.lower()):
|
||||||
loader = 'llama.cpp'
|
loader = 'llama.cpp'
|
||||||
elif re.match('.*rwkv.*\.pth', model_name.lower()):
|
elif re.match(r'.*rwkv.*\.pth', model_name.lower()):
|
||||||
loader = 'RWKV'
|
loader = 'RWKV'
|
||||||
else:
|
else:
|
||||||
loader = 'Transformers'
|
loader = 'Transformers'
|
||||||
|
51
modules/prompts.py
Normal file
51
modules/prompts.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from modules import utils
|
||||||
|
from modules.text_generation import get_encoded_length
|
||||||
|
|
||||||
|
|
||||||
|
def load_prompt(fname):
|
||||||
|
if fname in ['None', '']:
|
||||||
|
return ''
|
||||||
|
elif fname.startswith('Instruct-'):
|
||||||
|
fname = re.sub('^Instruct-', '', fname)
|
||||||
|
file_path = Path(f'characters/instruction-following/{fname}.yaml')
|
||||||
|
if not file_path.exists():
|
||||||
|
return ''
|
||||||
|
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
output = ''
|
||||||
|
if 'context' in data:
|
||||||
|
output += data['context']
|
||||||
|
|
||||||
|
replacements = {
|
||||||
|
'<|user|>': data['user'],
|
||||||
|
'<|bot|>': data['bot'],
|
||||||
|
'<|user-message|>': 'Input',
|
||||||
|
}
|
||||||
|
|
||||||
|
output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements)
|
||||||
|
return output.rstrip(' ')
|
||||||
|
else:
|
||||||
|
file_path = Path(f'prompts/{fname}.txt')
|
||||||
|
if not file_path.exists():
|
||||||
|
return ''
|
||||||
|
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
text = f.read()
|
||||||
|
if text[-1] == '\n':
|
||||||
|
text = text[:-1]
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def count_tokens(text):
|
||||||
|
try:
|
||||||
|
tokens = get_encoded_length(text)
|
||||||
|
return f'{tokens} tokens in the input.'
|
||||||
|
except:
|
||||||
|
return 'Couldn\'t count the number of tokens. Is a tokenizer loaded?'
|
@ -31,8 +31,62 @@ def generate_reply(*args, **kwargs):
|
|||||||
shared.generation_lock.release()
|
shared.generation_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def get_max_prompt_length(state):
|
def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
||||||
return state['truncation_length'] - state['max_new_tokens']
|
|
||||||
|
# Find the appropriate generation function
|
||||||
|
generate_func = apply_extensions('custom_generate_reply')
|
||||||
|
if generate_func is None:
|
||||||
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
|
logger.error("No model is loaded! Select one in the Model tab.")
|
||||||
|
yield ''
|
||||||
|
return
|
||||||
|
|
||||||
|
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
|
||||||
|
generate_func = generate_reply_custom
|
||||||
|
else:
|
||||||
|
generate_func = generate_reply_HF
|
||||||
|
|
||||||
|
# Prepare the input
|
||||||
|
original_question = question
|
||||||
|
if not is_chat:
|
||||||
|
state = apply_extensions('state', state)
|
||||||
|
question = apply_extensions('input', question, state)
|
||||||
|
|
||||||
|
# Find the stopping strings
|
||||||
|
all_stop_strings = []
|
||||||
|
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||||
|
if type(st) is list and len(st) > 0:
|
||||||
|
all_stop_strings += st
|
||||||
|
|
||||||
|
if shared.args.verbose:
|
||||||
|
print(f'\n\n{question}\n--------------------\n')
|
||||||
|
|
||||||
|
shared.stop_everything = False
|
||||||
|
clear_torch_cache()
|
||||||
|
seed = set_manual_seed(state['seed'])
|
||||||
|
last_update = -1
|
||||||
|
reply = ''
|
||||||
|
is_stream = state['stream']
|
||||||
|
if len(all_stop_strings) > 0 and not state['stream']:
|
||||||
|
state = copy.deepcopy(state)
|
||||||
|
state['stream'] = True
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
|
||||||
|
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
|
||||||
|
if is_stream:
|
||||||
|
cur_time = time.time()
|
||||||
|
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
|
||||||
|
last_update = cur_time
|
||||||
|
yield reply
|
||||||
|
|
||||||
|
if stop_found:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not is_chat:
|
||||||
|
reply = apply_extensions('output', reply, state)
|
||||||
|
|
||||||
|
yield reply
|
||||||
|
|
||||||
|
|
||||||
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
||||||
@ -61,6 +115,10 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||||||
return input_ids.cuda()
|
return input_ids.cuda()
|
||||||
|
|
||||||
|
|
||||||
|
def decode(output_ids, skip_special_tokens=True):
|
||||||
|
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||||
|
|
||||||
|
|
||||||
def get_encoded_length(prompt):
|
def get_encoded_length(prompt):
|
||||||
length_after_extensions = apply_extensions('tokenized_length', prompt)
|
length_after_extensions = apply_extensions('tokenized_length', prompt)
|
||||||
if length_after_extensions is not None:
|
if length_after_extensions is not None:
|
||||||
@ -69,12 +127,36 @@ def get_encoded_length(prompt):
|
|||||||
return len(encode(prompt)[0])
|
return len(encode(prompt)[0])
|
||||||
|
|
||||||
|
|
||||||
def decode(output_ids, skip_special_tokens=True):
|
def get_max_prompt_length(state):
|
||||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
return state['truncation_length'] - state['max_new_tokens']
|
||||||
|
|
||||||
|
|
||||||
|
def generate_reply_wrapper(question, state, stopping_strings=None):
|
||||||
|
"""
|
||||||
|
Returns formatted outputs for the UI
|
||||||
|
"""
|
||||||
|
reply = question if not shared.is_seq2seq else ''
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
for reply in generate_reply(question, state, stopping_strings, is_chat=False):
|
||||||
|
if not shared.is_seq2seq:
|
||||||
|
reply = question + reply
|
||||||
|
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
|
||||||
|
def formatted_outputs(reply, model_name):
|
||||||
|
if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']):
|
||||||
|
reply = fix_gpt4chan(reply)
|
||||||
|
return reply, generate_4chan_html(reply)
|
||||||
|
else:
|
||||||
|
return reply, generate_basic_html(reply)
|
||||||
|
|
||||||
|
|
||||||
# Removes empty replies from gpt4chan outputs
|
|
||||||
def fix_gpt4chan(s):
|
def fix_gpt4chan(s):
|
||||||
|
"""
|
||||||
|
Removes empty replies from gpt4chan outputs
|
||||||
|
"""
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
||||||
s = re.sub("--- [0-9]*\n *\n---", "---", s)
|
s = re.sub("--- [0-9]*\n *\n---", "---", s)
|
||||||
@ -83,8 +165,10 @@ def fix_gpt4chan(s):
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# Fix the LaTeX equations in galactica
|
|
||||||
def fix_galactica(s):
|
def fix_galactica(s):
|
||||||
|
"""
|
||||||
|
Fix the LaTeX equations in GALACTICA
|
||||||
|
"""
|
||||||
s = s.replace(r'\[', r'$')
|
s = s.replace(r'\[', r'$')
|
||||||
s = s.replace(r'\]', r'$')
|
s = s.replace(r'\]', r'$')
|
||||||
s = s.replace(r'\(', r'$')
|
s = s.replace(r'\(', r'$')
|
||||||
@ -109,14 +193,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i
|
|||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
|
||||||
def formatted_outputs(reply, model_name):
|
|
||||||
if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']):
|
|
||||||
reply = fix_gpt4chan(reply)
|
|
||||||
return reply, generate_4chan_html(reply)
|
|
||||||
else:
|
|
||||||
return reply, generate_basic_html(reply)
|
|
||||||
|
|
||||||
|
|
||||||
def set_manual_seed(seed):
|
def set_manual_seed(seed):
|
||||||
seed = int(seed)
|
seed = int(seed)
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
@ -133,17 +209,6 @@ def stop_everything_event():
|
|||||||
shared.stop_everything = True
|
shared.stop_everything = True
|
||||||
|
|
||||||
|
|
||||||
def generate_reply_wrapper(question, state, stopping_strings=None):
|
|
||||||
reply = question if not shared.is_seq2seq else ''
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
|
||||||
|
|
||||||
for reply in generate_reply(question, state, stopping_strings, is_chat=False):
|
|
||||||
if not shared.is_seq2seq:
|
|
||||||
reply = question + reply
|
|
||||||
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_stopping_strings(reply, all_stop_strings):
|
def apply_stopping_strings(reply, all_stop_strings):
|
||||||
stop_found = False
|
stop_found = False
|
||||||
for string in all_stop_strings:
|
for string in all_stop_strings:
|
||||||
@ -169,61 +234,6 @@ def apply_stopping_strings(reply, all_stop_strings):
|
|||||||
return reply, stop_found
|
return reply, stop_found
|
||||||
|
|
||||||
|
|
||||||
def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
|
||||||
generate_func = apply_extensions('custom_generate_reply')
|
|
||||||
if generate_func is None:
|
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
|
||||||
logger.error("No model is loaded! Select one in the Model tab.")
|
|
||||||
yield ''
|
|
||||||
return
|
|
||||||
|
|
||||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
|
|
||||||
generate_func = generate_reply_custom
|
|
||||||
else:
|
|
||||||
generate_func = generate_reply_HF
|
|
||||||
|
|
||||||
# Preparing the input
|
|
||||||
original_question = question
|
|
||||||
if not is_chat:
|
|
||||||
state = apply_extensions('state', state)
|
|
||||||
question = apply_extensions('input', question, state)
|
|
||||||
|
|
||||||
# Finding the stopping strings
|
|
||||||
all_stop_strings = []
|
|
||||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
|
||||||
if type(st) is list and len(st) > 0:
|
|
||||||
all_stop_strings += st
|
|
||||||
|
|
||||||
if shared.args.verbose:
|
|
||||||
print(f'\n\n{question}\n--------------------\n')
|
|
||||||
|
|
||||||
shared.stop_everything = False
|
|
||||||
clear_torch_cache()
|
|
||||||
seed = set_manual_seed(state['seed'])
|
|
||||||
last_update = -1
|
|
||||||
reply = ''
|
|
||||||
is_stream = state['stream']
|
|
||||||
if len(all_stop_strings) > 0 and not state['stream']:
|
|
||||||
state = copy.deepcopy(state)
|
|
||||||
state['stream'] = True
|
|
||||||
|
|
||||||
for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
|
|
||||||
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
|
|
||||||
if is_stream:
|
|
||||||
cur_time = time.time()
|
|
||||||
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
|
|
||||||
last_update = cur_time
|
|
||||||
yield reply
|
|
||||||
|
|
||||||
if stop_found:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not is_chat:
|
|
||||||
reply = apply_extensions('output', reply, state)
|
|
||||||
|
|
||||||
yield reply
|
|
||||||
|
|
||||||
|
|
||||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
||||||
@ -316,6 +326,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
|
|
||||||
|
|
||||||
def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||||
|
"""
|
||||||
|
For models that do not use the transformers library for sampling
|
||||||
|
"""
|
||||||
seed = set_manual_seed(state['seed'])
|
seed = set_manual_seed(state['seed'])
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
@ -17,8 +17,6 @@ from pathlib import Path
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from modules.models import load_model, unload_model
|
|
||||||
|
|
||||||
from datasets import Dataset, load_dataset
|
from datasets import Dataset, load_dataset
|
||||||
from peft import (
|
from peft import (
|
||||||
LoraConfig,
|
LoraConfig,
|
||||||
@ -34,6 +32,7 @@ from modules.evaluate import (
|
|||||||
save_past_evaluations
|
save_past_evaluations
|
||||||
)
|
)
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
from modules.models import load_model, unload_model
|
||||||
from modules.utils import natural_keys
|
from modules.utils import natural_keys
|
||||||
|
|
||||||
# This mapping is from a very recent commit, not yet released.
|
# This mapping is from a very recent commit, not yet released.
|
||||||
@ -65,7 +64,9 @@ WANT_INTERRUPT = False
|
|||||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
|
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
|
||||||
|
|
||||||
|
|
||||||
def create_train_interface():
|
def create_ui():
|
||||||
|
with gr.Tab("Training", elem_id="training-tab"):
|
||||||
|
tmp = gr.State('')
|
||||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||||
gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
|
gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
|
||||||
|
|
||||||
@ -158,7 +159,6 @@ def create_train_interface():
|
|||||||
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
|
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
|
||||||
|
|
||||||
# Training events
|
# Training events
|
||||||
|
|
||||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
|
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
|
||||||
|
|
||||||
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||||
@ -172,7 +172,6 @@ def create_train_interface():
|
|||||||
ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
|
ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
|
||||||
start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
|
start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
|
||||||
|
|
||||||
tmp = gr.State('')
|
|
||||||
start_current_evaluation.click(lambda: ['current model'], None, tmp)
|
start_current_evaluation.click(lambda: ['current model'], None, tmp)
|
||||||
ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
|
ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
|
||||||
start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
|
start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -11,9 +10,9 @@ with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
|
|||||||
css = f.read()
|
css = f.read()
|
||||||
with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
|
with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
|
||||||
chat_css = f.read()
|
chat_css = f.read()
|
||||||
with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
|
with open(Path(__file__).resolve().parent / '../js/main.js', 'r') as f:
|
||||||
main_js = f.read()
|
main_js = f.read()
|
||||||
with open(Path(__file__).resolve().parent / '../css/save_files.js', 'r') as f:
|
with open(Path(__file__).resolve().parent / '../js/save_files.js', 'r') as f:
|
||||||
save_files_js = f.read()
|
save_files_js = f.read()
|
||||||
|
|
||||||
refresh_symbol = '🔄'
|
refresh_symbol = '🔄'
|
||||||
@ -30,6 +29,11 @@ theme = gr.themes.Default(
|
|||||||
background_fill_secondary='#eaeaea'
|
background_fill_secondary='#eaeaea'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if Path("notification.mp3").exists():
|
||||||
|
audio_notification_js = "document.querySelector('#audio_notification audio')?.play();"
|
||||||
|
else:
|
||||||
|
audio_notification_js = ""
|
||||||
|
|
||||||
|
|
||||||
def list_model_elements():
|
def list_model_elements():
|
||||||
elements = [
|
elements = [
|
||||||
|
262
modules/ui_chat.py
Normal file
262
modules/ui_chat.py
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
import json
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from modules import chat, shared, ui, utils
|
||||||
|
from modules.html_generator import chat_html_wrapper
|
||||||
|
from modules.text_generation import stop_everything_event
|
||||||
|
from modules.utils import gradio
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui():
|
||||||
|
|
||||||
|
shared.gradio.update({
|
||||||
|
'interface_state': gr.State({k: None for k in shared.input_elements}),
|
||||||
|
'Chat input': gr.State(),
|
||||||
|
'dummy': gr.State(),
|
||||||
|
'history': gr.State({'internal': [], 'visible': []}),
|
||||||
|
})
|
||||||
|
|
||||||
|
with gr.Tab('Text generation', elem_id='main'):
|
||||||
|
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper({'internal': [], 'visible': []}, shared.settings['name1'], shared.settings['name2'], 'chat', 'cai-chat'))
|
||||||
|
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
|
||||||
|
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate', variant='primary')
|
||||||
|
shared.gradio['Continue'] = gr.Button('Continue')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||||
|
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||||
|
shared.gradio['Remove last'] = gr.Button('Remove last', elem_classes=['button_nowrap'])
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
||||||
|
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
||||||
|
shared.gradio['Send dummy message'] = gr.Button('Send dummy message')
|
||||||
|
shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['Clear history'] = gr.Button('Clear history')
|
||||||
|
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant='stop', visible=False)
|
||||||
|
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['start_with'] = gr.Textbox(label='Start reply with', placeholder='Sure thing!', value=shared.settings['start_with'])
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['mode'] = gr.Radio(choices=['chat', 'chat-instruct', 'instruct'], value=shared.settings['mode'] if shared.settings['mode'] in ['chat', 'instruct', 'chat-instruct'] else 'chat', label='Mode', info='Defines how the chat prompt is generated. In instruct and chat-instruct modes, the instruction template selected under "Chat settings" must match the current model.')
|
||||||
|
shared.gradio['chat_style'] = gr.Dropdown(choices=utils.get_available_chat_styles(), label='Chat style', value=shared.settings['chat_style'], visible=shared.settings['mode'] != 'instruct')
|
||||||
|
|
||||||
|
with gr.Tab('Chat settings', elem_id='chat-settings'):
|
||||||
|
with gr.Tab("Character"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=8):
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['character_menu'] = gr.Dropdown(value='None', choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.', elem_classes='slim-dropdown')
|
||||||
|
ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button')
|
||||||
|
shared.gradio['save_character'] = gr.Button('💾', elem_classes='refresh-button')
|
||||||
|
shared.gradio['delete_character'] = gr.Button('🗑️', elem_classes='refresh-button')
|
||||||
|
|
||||||
|
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
|
||||||
|
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
|
||||||
|
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context', elem_classes=['add_scrollbar'])
|
||||||
|
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting', elem_classes=['add_scrollbar'])
|
||||||
|
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil')
|
||||||
|
shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None)
|
||||||
|
|
||||||
|
with gr.Tab("Instruction template"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Instruction template', value='None', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.', elem_classes='slim-dropdown')
|
||||||
|
ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button')
|
||||||
|
shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button')
|
||||||
|
shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button')
|
||||||
|
|
||||||
|
shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string')
|
||||||
|
shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string')
|
||||||
|
shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context')
|
||||||
|
shared.gradio['turn_template'] = gr.Textbox(value=shared.settings['turn_template'], lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.')
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.', elem_classes=['add_scrollbar'])
|
||||||
|
|
||||||
|
with gr.Tab('Chat history'):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['save_chat_history'] = gr.Button(value='Save history')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['load_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'], label="Upload History JSON")
|
||||||
|
|
||||||
|
with gr.Tab('Upload character'):
|
||||||
|
with gr.Tab('YAML or JSON'):
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json', '.yaml'], label='JSON or YAML File')
|
||||||
|
shared.gradio['upload_img_bot'] = gr.Image(type='pil', label='Profile Picture (optional)')
|
||||||
|
|
||||||
|
shared.gradio['Submit character'] = gr.Button(value='Submit', interactive=False)
|
||||||
|
|
||||||
|
with gr.Tab('TavernAI PNG'):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['upload_img_tavern'] = gr.Image(type='pil', label='TavernAI PNG File', elem_id="upload_img_tavern")
|
||||||
|
shared.gradio['tavern_json'] = gr.State()
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['tavern_name'] = gr.Textbox(value='', lines=1, label='Name', interactive=False)
|
||||||
|
shared.gradio['tavern_desc'] = gr.Textbox(value='', lines=4, max_lines=4, label='Description', interactive=False)
|
||||||
|
|
||||||
|
shared.gradio['Submit tavern character'] = gr.Button(value='Submit', interactive=False)
|
||||||
|
|
||||||
|
|
||||||
|
def create_event_handlers():
|
||||||
|
gen_events = []
|
||||||
|
|
||||||
|
shared.input_params = gradio('Chat input', 'start_with', 'interface_state')
|
||||||
|
clear_arr = gradio('Clear history-confirm', 'Clear history', 'Clear history-cancel')
|
||||||
|
shared.reload_inputs = gradio('history', 'name1', 'name2', 'mode', 'chat_style')
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then(
|
||||||
|
chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then(
|
||||||
|
chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Regenerate'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
partial(chat.generate_chat_reply_wrapper, regenerate=True), shared.input_params, gradio('display', 'history'), show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
partial(chat.generate_chat_reply_wrapper, _continue=True), shared.input_params, gradio('display', 'history'), show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Impersonate'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda x: x, gradio('textbox'), gradio('Chat input'), show_progress=False).then(
|
||||||
|
chat.impersonate_wrapper, shared.input_params, gradio('textbox'), show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['Replace last reply'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.replace_last_reply, gradio('textbox', 'interface_state'), gradio('history')).then(
|
||||||
|
lambda: '', None, gradio('textbox'), show_progress=False).then(
|
||||||
|
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||||
|
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
|
||||||
|
|
||||||
|
shared.gradio['Send dummy message'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.send_dummy_message, gradio('textbox', 'interface_state'), gradio('history')).then(
|
||||||
|
lambda: '', None, gradio('textbox'), show_progress=False).then(
|
||||||
|
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||||
|
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
|
||||||
|
|
||||||
|
shared.gradio['Send dummy reply'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.send_dummy_reply, gradio('textbox', 'interface_state'), gradio('history')).then(
|
||||||
|
lambda: '', None, gradio('textbox'), show_progress=False).then(
|
||||||
|
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||||
|
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
|
||||||
|
|
||||||
|
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||||
|
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||||
|
shared.gradio['Clear history-confirm'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
||||||
|
chat.clear_chat_log, gradio('interface_state'), gradio('history')).then(
|
||||||
|
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||||
|
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
|
||||||
|
|
||||||
|
shared.gradio['Remove last'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.remove_last_message, gradio('history'), gradio('textbox', 'history'), show_progress=False).then(
|
||||||
|
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||||
|
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
|
||||||
|
|
||||||
|
shared.gradio['character_menu'].change(
|
||||||
|
partial(chat.load_character, instruct=False), gradio('character_menu', 'name1', 'name2'), gradio('name1', 'name2', 'character_picture', 'greeting', 'context', 'dummy')).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.load_persistent_history, gradio('interface_state'), gradio('history')).then(
|
||||||
|
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||||
|
|
||||||
|
shared.gradio['Stop'].click(
|
||||||
|
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
||||||
|
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||||
|
|
||||||
|
shared.gradio['mode'].change(
|
||||||
|
lambda x: gr.update(visible=x != 'instruct'), gradio('mode'), gradio('chat_style'), show_progress=False).then(
|
||||||
|
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||||
|
|
||||||
|
shared.gradio['chat_style'].change(chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||||
|
shared.gradio['instruction_template'].change(
|
||||||
|
partial(chat.load_character, instruct=True), gradio('instruction_template', 'name1_instruct', 'name2_instruct'), gradio('name1_instruct', 'name2_instruct', 'dummy', 'dummy', 'context_instruct', 'turn_template'))
|
||||||
|
|
||||||
|
shared.gradio['load_chat_history'].upload(
|
||||||
|
chat.load_history, gradio('load_chat_history', 'history'), gradio('history')).then(
|
||||||
|
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||||
|
None, None, None, _js='() => {alert("The history has been loaded.")}')
|
||||||
|
|
||||||
|
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, gradio('history'), gradio('textbox'), show_progress=False)
|
||||||
|
|
||||||
|
# Save/delete a character
|
||||||
|
shared.gradio['save_character'].click(
|
||||||
|
lambda x: x, gradio('name2'), gradio('save_character_filename')).then(
|
||||||
|
lambda: gr.update(visible=True), None, gradio('character_saver'))
|
||||||
|
|
||||||
|
shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter'))
|
||||||
|
|
||||||
|
shared.gradio['save_template'].click(
|
||||||
|
lambda: 'My Template.yaml', None, gradio('save_filename')).then(
|
||||||
|
lambda: 'characters/instruction-following/', None, gradio('save_root')).then(
|
||||||
|
chat.generate_instruction_template_yaml, gradio('name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template'), gradio('save_contents')).then(
|
||||||
|
lambda: gr.update(visible=True), None, gradio('file_saver'))
|
||||||
|
|
||||||
|
shared.gradio['delete_template'].click(
|
||||||
|
lambda x: f'{x}.yaml', gradio('instruction_template'), gradio('delete_filename')).then(
|
||||||
|
lambda: 'characters/instruction-following/', None, gradio('delete_root')).then(
|
||||||
|
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||||
|
|
||||||
|
shared.gradio['save_chat_history'].click(
|
||||||
|
lambda x: json.dumps(x, indent=4), gradio('history'), gradio('temporary_text')).then(
|
||||||
|
None, gradio('temporary_text', 'character_menu', 'mode'), None, _js=f"(hist, char, mode) => {{{ui.save_files_js}; saveHistory(hist, char, mode)}}")
|
||||||
|
|
||||||
|
shared.gradio['Submit character'].click(
|
||||||
|
chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu')).then(
|
||||||
|
None, None, None, _js='() => {alert("The character has been loaded.")}')
|
||||||
|
|
||||||
|
shared.gradio['Submit tavern character'].click(
|
||||||
|
chat.upload_tavern_character, gradio('upload_img_tavern', 'tavern_json'), gradio('character_menu')).then(
|
||||||
|
None, None, None, _js='() => {alert("The character has been loaded.")}')
|
||||||
|
|
||||||
|
shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character'))
|
||||||
|
shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character'))
|
||||||
|
shared.gradio['upload_img_tavern'].upload(chat.check_tavern_character, gradio('upload_img_tavern'), gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False)
|
||||||
|
shared.gradio['upload_img_tavern'].clear(lambda: (None, None, None, gr.update(interactive=False)), None, gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False)
|
||||||
|
shared.gradio['your_picture'].change(
|
||||||
|
chat.upload_your_profile_picture, gradio('your_picture'), None).then(
|
||||||
|
partial(chat.redraw_html, reset_cache=True), shared.reload_inputs, gradio('display'))
|
94
modules/ui_default.py
Normal file
94
modules/ui_default.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import shared, ui, utils
|
||||||
|
from modules.prompts import count_tokens, load_prompt
|
||||||
|
from modules.text_generation import (
|
||||||
|
generate_reply_wrapper,
|
||||||
|
stop_everything_event
|
||||||
|
)
|
||||||
|
from modules.utils import gradio
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui():
|
||||||
|
default_text = load_prompt(shared.settings['prompt'])
|
||||||
|
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
|
shared.gradio['last_input'] = gr.State('')
|
||||||
|
|
||||||
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes=['textbox_default', 'add_scrollbar'], lines=27, label='Input')
|
||||||
|
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['Generate'] = gr.Button('Generate', variant='primary')
|
||||||
|
shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
|
||||||
|
shared.gradio['Continue'] = gr.Button('Continue')
|
||||||
|
shared.gradio['count_tokens'] = gr.Button('Count tokens')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['prompt_menu'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt', elem_classes='slim-dropdown')
|
||||||
|
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 'refresh-button')
|
||||||
|
shared.gradio['save_prompt'] = gr.Button('💾', elem_classes='refresh-button')
|
||||||
|
shared.gradio['delete_prompt'] = gr.Button('🗑️', elem_classes='refresh-button')
|
||||||
|
|
||||||
|
shared.gradio['status'] = gr.Markdown('')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Tab('Raw'):
|
||||||
|
shared.gradio['output_textbox'] = gr.Textbox(lines=27, label='Output', elem_classes=['textbox_default_output', 'add_scrollbar'])
|
||||||
|
|
||||||
|
with gr.Tab('Markdown'):
|
||||||
|
shared.gradio['markdown_render'] = gr.Button('Render')
|
||||||
|
shared.gradio['markdown'] = gr.Markdown()
|
||||||
|
|
||||||
|
with gr.Tab('HTML'):
|
||||||
|
shared.gradio['html'] = gr.HTML()
|
||||||
|
|
||||||
|
|
||||||
|
def create_event_handlers():
|
||||||
|
gen_events = []
|
||||||
|
shared.input_params = gradio('textbox', 'interface_state')
|
||||||
|
output_params = gradio('output_textbox', 'html')
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
|
lambda x: x, gradio('textbox'), gradio('last_input')).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
lambda x: x, gradio('textbox'), gradio('last_input')).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['markdown_render'].click(lambda x: x, gradio('output_textbox'), gradio('markdown'), queue=False)
|
||||||
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
generate_reply_wrapper, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
|
shared.gradio['prompt_menu'].change(load_prompt, gradio('prompt_menu'), gradio('textbox'), show_progress=False)
|
||||||
|
shared.gradio['save_prompt'].click(
|
||||||
|
lambda x: x, gradio('textbox'), gradio('save_contents')).then(
|
||||||
|
lambda: 'prompts/', None, gradio('save_root')).then(
|
||||||
|
lambda: utils.current_time() + '.txt', None, gradio('save_filename')).then(
|
||||||
|
lambda: gr.update(visible=True), None, gradio('file_saver'))
|
||||||
|
|
||||||
|
shared.gradio['delete_prompt'].click(
|
||||||
|
lambda: 'prompts/', None, gradio('delete_root')).then(
|
||||||
|
lambda x: x + '.txt', gradio('prompt_menu'), gradio('delete_filename')).then(
|
||||||
|
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||||
|
|
||||||
|
shared.gradio['count_tokens'].click(count_tokens, gradio('textbox'), gradio('status'), show_progress=False)
|
108
modules/ui_file_saving.py
Normal file
108
modules/ui_file_saving.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import chat, presets, shared, ui, utils
|
||||||
|
from modules.utils import gradio
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui():
|
||||||
|
|
||||||
|
# Text file saver
|
||||||
|
with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']:
|
||||||
|
shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name')
|
||||||
|
shared.gradio['save_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.', interactive=False)
|
||||||
|
shared.gradio['save_contents'] = gr.Textbox(lines=10, label='File contents')
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['save_confirm'] = gr.Button('Save', elem_classes="small-button")
|
||||||
|
shared.gradio['save_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
|
|
||||||
|
# Text file deleter
|
||||||
|
with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['file_deleter']:
|
||||||
|
shared.gradio['delete_filename'] = gr.Textbox(lines=1, label='File name')
|
||||||
|
shared.gradio['delete_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.', interactive=False)
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['delete_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop')
|
||||||
|
shared.gradio['delete_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
|
|
||||||
|
# Character saver/deleter
|
||||||
|
if shared.is_chat():
|
||||||
|
with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']:
|
||||||
|
shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info='The character will be saved to your characters/ folder with this base filename.')
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button")
|
||||||
|
shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
|
|
||||||
|
with gr.Box(visible=False, elem_classes='file-saver') as shared.gradio['character_deleter']:
|
||||||
|
gr.Markdown('Confirm the character deletion?')
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop')
|
||||||
|
shared.gradio['delete_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
|
|
||||||
|
|
||||||
|
def create_event_handlers():
|
||||||
|
shared.gradio['save_confirm'].click(
|
||||||
|
lambda x, y, z: utils.save_file(x + y, z), gradio('save_root', 'save_filename', 'save_contents'), None).then(
|
||||||
|
lambda: gr.update(visible=False), None, gradio('file_saver'))
|
||||||
|
|
||||||
|
shared.gradio['delete_confirm'].click(
|
||||||
|
lambda x, y: utils.delete_file(x + y), gradio('delete_root', 'delete_filename'), None).then(
|
||||||
|
lambda: gr.update(visible=False), None, gradio('file_deleter'))
|
||||||
|
|
||||||
|
shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_deleter'))
|
||||||
|
shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_saver'))
|
||||||
|
if shared.is_chat():
|
||||||
|
shared.gradio['save_character_confirm'].click(
|
||||||
|
chat.save_character, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), None).then(
|
||||||
|
lambda: gr.update(visible=False), None, gradio('character_saver'))
|
||||||
|
|
||||||
|
shared.gradio['delete_character_confirm'].click(
|
||||||
|
chat.delete_character, gradio('character_menu'), None).then(
|
||||||
|
lambda: gr.update(visible=False), None, gradio('character_deleter')).then(
|
||||||
|
lambda: gr.update(choices=utils.get_available_characters()), None, gradio('character_menu'))
|
||||||
|
|
||||||
|
shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver'))
|
||||||
|
shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter'))
|
||||||
|
|
||||||
|
shared.gradio['save_preset'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
presets.generate_preset_yaml, gradio('interface_state'), gradio('save_contents')).then(
|
||||||
|
lambda: 'presets/', None, gradio('save_root')).then(
|
||||||
|
lambda: 'My Preset.yaml', None, gradio('save_filename')).then(
|
||||||
|
lambda: gr.update(visible=True), None, gradio('file_saver'))
|
||||||
|
|
||||||
|
shared.gradio['delete_preset'].click(
|
||||||
|
lambda x: f'{x}.yaml', gradio('preset_menu'), gradio('delete_filename')).then(
|
||||||
|
lambda: 'presets/', None, gradio('delete_root')).then(
|
||||||
|
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||||
|
|
||||||
|
if not shared.args.multi_user:
|
||||||
|
shared.gradio['save_session'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda x: json.dumps(x, indent=4), gradio('interface_state'), gradio('temporary_text')).then(
|
||||||
|
None, gradio('temporary_text'), None, _js=f"(contents) => {{{ui.save_files_js}; saveSession(contents, \"{shared.get_mode()}\")}}")
|
||||||
|
|
||||||
|
if shared.is_chat():
|
||||||
|
shared.gradio['load_session'].upload(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
load_session, gradio('load_session', 'interface_state'), gradio('interface_state')).then(
|
||||||
|
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
|
||||||
|
chat.redraw_html, shared.reload_inputs, gradio('display')).then(
|
||||||
|
None, None, None, _js='() => {alert("The session has been loaded.")}')
|
||||||
|
else:
|
||||||
|
shared.gradio['load_session'].upload(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
load_session, gradio('load_session', 'interface_state'), gradio('interface_state')).then(
|
||||||
|
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
|
||||||
|
None, None, None, _js='() => {alert("The session has been loaded.")}')
|
||||||
|
|
||||||
|
|
||||||
|
def load_session(file, state):
|
||||||
|
decoded_file = file if type(file) == str else file.decode('utf-8')
|
||||||
|
data = json.loads(decoded_file)
|
||||||
|
|
||||||
|
if shared.is_chat() and 'character_menu' in data and state.get('character_menu') != data.get('character_menu'):
|
||||||
|
shared.session_is_loading = True
|
||||||
|
|
||||||
|
state.update(data)
|
||||||
|
return state
|
229
modules/ui_model_menu.py
Normal file
229
modules/ui_model_menu.py
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
import importlib
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
import traceback
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import loaders, shared, ui, utils
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
from modules.LoRA import add_lora_to_model
|
||||||
|
from modules.models import load_model, unload_model
|
||||||
|
from modules.models_settings import (
|
||||||
|
apply_model_settings_to_state,
|
||||||
|
save_model_settings,
|
||||||
|
update_model_parameters
|
||||||
|
)
|
||||||
|
from modules.utils import gradio
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui():
|
||||||
|
# Finding the default values for the GPU and CPU memories
|
||||||
|
total_mem = []
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
|
||||||
|
|
||||||
|
default_gpu_mem = []
|
||||||
|
if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
|
||||||
|
for i in shared.args.gpu_memory:
|
||||||
|
if 'mib' in i.lower():
|
||||||
|
default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)))
|
||||||
|
else:
|
||||||
|
default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000)
|
||||||
|
while len(default_gpu_mem) < len(total_mem):
|
||||||
|
default_gpu_mem.append(0)
|
||||||
|
|
||||||
|
total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024))
|
||||||
|
if shared.args.cpu_memory is not None:
|
||||||
|
default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory)
|
||||||
|
else:
|
||||||
|
default_cpu_mem = 0
|
||||||
|
|
||||||
|
with gr.Tab("Model", elem_id="model-tab"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['model_menu'] = gr.Dropdown(choices=utils.get_available_models(), value=shared.model_name, label='Model', elem_classes='slim-dropdown')
|
||||||
|
ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button')
|
||||||
|
shared.gradio['load_model'] = gr.Button("Load", visible=not shared.settings['autoload_model'], elem_classes='refresh-button')
|
||||||
|
shared.gradio['unload_model'] = gr.Button("Unload", elem_classes='refresh-button')
|
||||||
|
shared.gradio['reload_model'] = gr.Button("Reload", elem_classes='refresh-button')
|
||||||
|
shared.gradio['save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=utils.get_available_loras(), value=shared.lora_names, label='LoRA(s)', elem_classes='slim-dropdown')
|
||||||
|
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
|
||||||
|
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "ExLlama_HF", "ExLlama", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp", "llamacpp_HF"], value=None)
|
||||||
|
with gr.Box():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
for i in range(len(total_mem)):
|
||||||
|
shared.gradio[f'gpu_memory_{i}'] = gr.Slider(label=f"gpu-memory in MiB for device :{i}", maximum=total_mem[i], value=default_gpu_mem[i])
|
||||||
|
|
||||||
|
shared.gradio['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem, value=default_cpu_mem)
|
||||||
|
shared.gradio['transformers_info'] = gr.Markdown('load-in-4bit params:')
|
||||||
|
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype)
|
||||||
|
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type)
|
||||||
|
|
||||||
|
shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128, value=shared.args.n_gpu_layers)
|
||||||
|
shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx", value=shared.args.n_ctx)
|
||||||
|
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32, value=shared.args.threads)
|
||||||
|
shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch)
|
||||||
|
shared.gradio['n_gqa'] = gr.Slider(minimum=0, maximum=16, step=1, label="n_gqa", value=shared.args.n_gqa, info='grouped-query attention. Must be 8 for llama-2 70b.')
|
||||||
|
shared.gradio['rms_norm_eps'] = gr.Slider(minimum=0, maximum=1e-5, step=1e-6, label="rms_norm_eps", value=shared.args.n_gqa, info='5e-6 is a good value for llama-2 models.')
|
||||||
|
|
||||||
|
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None")
|
||||||
|
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")
|
||||||
|
shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
|
||||||
|
shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
|
||||||
|
shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.')
|
||||||
|
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
|
||||||
|
shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=0, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len)
|
||||||
|
shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.', value=shared.args.compress_pos_emb)
|
||||||
|
shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=32, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Scaling is not identical to embedding compression. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value)
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
|
||||||
|
shared.gradio['no_inject_fused_attention'] = gr.Checkbox(label="no_inject_fused_attention", value=shared.args.no_inject_fused_attention, info='Disable fused attention. Fused attention improves inference performance but uses more VRAM. Disable if running low on VRAM.')
|
||||||
|
shared.gradio['no_inject_fused_mlp'] = gr.Checkbox(label="no_inject_fused_mlp", value=shared.args.no_inject_fused_mlp, info='Affects Triton only. Disable fused MLP. Fused MLP improves performance but uses more VRAM. Disable if running low on VRAM.')
|
||||||
|
shared.gradio['no_use_cuda_fp16'] = gr.Checkbox(label="no_use_cuda_fp16", value=shared.args.no_use_cuda_fp16, info='This can make models faster on some systems.')
|
||||||
|
shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.')
|
||||||
|
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu)
|
||||||
|
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
|
||||||
|
shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
|
||||||
|
shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
|
||||||
|
shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
|
||||||
|
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
|
||||||
|
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant)
|
||||||
|
shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
|
||||||
|
shared.gradio['low_vram'] = gr.Checkbox(label="low-vram", value=shared.args.low_vram)
|
||||||
|
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
|
||||||
|
shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed)
|
||||||
|
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
|
||||||
|
shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
|
||||||
|
shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
|
||||||
|
shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')
|
||||||
|
shared.gradio['llamacpp_HF_info'] = gr.Markdown('llamacpp_HF is a wrapper that lets you use llama.cpp like a Transformers model, which means it can use the Transformers samplers. To use it, make sure to first download oobabooga/llama-tokenizer under "Download custom model or LoRA".')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.')
|
||||||
|
|
||||||
|
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
|
||||||
|
shared.gradio['download_model_button'] = gr.Button("Download")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
|
||||||
|
|
||||||
|
|
||||||
|
def create_event_handlers():
|
||||||
|
shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()))
|
||||||
|
|
||||||
|
# In this event handler, the interface state is read and updated
|
||||||
|
# with the model defaults (if any), and then the model is loaded
|
||||||
|
# unless "autoload_model" is unchecked
|
||||||
|
shared.gradio['model_menu'].change(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
apply_model_settings_to_state, gradio('model_menu', 'interface_state'), gradio('interface_state')).then(
|
||||||
|
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
|
||||||
|
update_model_parameters, gradio('interface_state'), None).then(
|
||||||
|
load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['load_model'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
update_model_parameters, gradio('interface_state'), None).then(
|
||||||
|
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['unload_model'].click(
|
||||||
|
unload_model, None, None).then(
|
||||||
|
lambda: "Model unloaded", None, gradio('model_status'))
|
||||||
|
|
||||||
|
shared.gradio['reload_model'].click(
|
||||||
|
unload_model, None, None).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
update_model_parameters, gradio('interface_state'), None).then(
|
||||||
|
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['save_model_settings'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False)
|
||||||
|
shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu'), gradio('model_status'), show_progress=True)
|
||||||
|
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), gradio('load_model'))
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_wrapper(selected_model, loader, autoload=False):
|
||||||
|
if not autoload:
|
||||||
|
yield f"The settings for {selected_model} have been updated.\nClick on \"Load\" to load it."
|
||||||
|
return
|
||||||
|
|
||||||
|
if selected_model == 'None':
|
||||||
|
yield "No model selected"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
yield f"Loading {selected_model}..."
|
||||||
|
shared.model_name = selected_model
|
||||||
|
unload_model()
|
||||||
|
if selected_model != '':
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.model_name, loader)
|
||||||
|
|
||||||
|
if shared.model is not None:
|
||||||
|
yield f"Successfully loaded {selected_model}"
|
||||||
|
else:
|
||||||
|
yield f"Failed to load {selected_model}."
|
||||||
|
except:
|
||||||
|
exc = traceback.format_exc()
|
||||||
|
logger.error('Failed to load the model.')
|
||||||
|
print(exc)
|
||||||
|
yield exc.replace('\n', '\n\n')
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora_wrapper(selected_loras):
|
||||||
|
yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras)))
|
||||||
|
add_lora_to_model(selected_loras)
|
||||||
|
yield ("Successfuly applied the LoRAs")
|
||||||
|
|
||||||
|
|
||||||
|
def download_model_wrapper(repo_id, progress=gr.Progress()):
|
||||||
|
try:
|
||||||
|
downloader_module = importlib.import_module("download-model")
|
||||||
|
downloader = downloader_module.ModelDownloader()
|
||||||
|
repo_id_parts = repo_id.split(":")
|
||||||
|
model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id
|
||||||
|
branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"
|
||||||
|
check = False
|
||||||
|
|
||||||
|
progress(0.0)
|
||||||
|
yield ("Cleaning up the model/branch names")
|
||||||
|
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||||
|
|
||||||
|
yield ("Getting the download links from Hugging Face")
|
||||||
|
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||||
|
|
||||||
|
yield ("Getting the output folder")
|
||||||
|
base_folder = shared.args.lora_dir if is_lora else shared.args.model_dir
|
||||||
|
output_folder = downloader.get_output_folder(model, branch, is_lora, base_folder=base_folder)
|
||||||
|
|
||||||
|
if check:
|
||||||
|
progress(0.5)
|
||||||
|
yield ("Checking previously downloaded files")
|
||||||
|
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||||
|
progress(1.0)
|
||||||
|
else:
|
||||||
|
yield (f"Downloading files to {output_folder}")
|
||||||
|
downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=1)
|
||||||
|
yield ("Done!")
|
||||||
|
except:
|
||||||
|
progress(1.0)
|
||||||
|
yield traceback.format_exc().replace('\n', '\n\n')
|
98
modules/ui_notebook.py
Normal file
98
modules/ui_notebook.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import shared, ui, utils
|
||||||
|
from modules.prompts import count_tokens, load_prompt
|
||||||
|
from modules.text_generation import (
|
||||||
|
generate_reply_wrapper,
|
||||||
|
stop_everything_event
|
||||||
|
)
|
||||||
|
from modules.utils import gradio
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui():
|
||||||
|
default_text = load_prompt(shared.settings['prompt'])
|
||||||
|
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
|
shared.gradio['last_input'] = gr.State('')
|
||||||
|
|
||||||
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=4):
|
||||||
|
with gr.Tab('Raw'):
|
||||||
|
shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes=['textbox', 'add_scrollbar'], lines=27)
|
||||||
|
|
||||||
|
with gr.Tab('Markdown'):
|
||||||
|
shared.gradio['markdown_render'] = gr.Button('Render')
|
||||||
|
shared.gradio['markdown'] = gr.Markdown()
|
||||||
|
|
||||||
|
with gr.Tab('HTML'):
|
||||||
|
shared.gradio['html'] = gr.HTML()
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button")
|
||||||
|
shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button", elem_id='stop')
|
||||||
|
shared.gradio['Undo'] = gr.Button('Undo', elem_classes="small-button")
|
||||||
|
shared.gradio['Regenerate'] = gr.Button('Regenerate', elem_classes="small-button")
|
||||||
|
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
gr.HTML('<div style="padding-bottom: 13px"></div>')
|
||||||
|
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['prompt_menu'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt', elem_classes='slim-dropdown')
|
||||||
|
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, ['refresh-button', 'refresh-button-small'])
|
||||||
|
shared.gradio['save_prompt'] = gr.Button('💾', elem_classes=['refresh-button', 'refresh-button-small'])
|
||||||
|
shared.gradio['delete_prompt'] = gr.Button('🗑️', elem_classes=['refresh-button', 'refresh-button-small'])
|
||||||
|
|
||||||
|
shared.gradio['count_tokens'] = gr.Button('Count tokens')
|
||||||
|
shared.gradio['status'] = gr.Markdown('')
|
||||||
|
|
||||||
|
|
||||||
|
def create_event_handlers():
|
||||||
|
gen_events = []
|
||||||
|
|
||||||
|
shared.input_params = gradio('textbox', 'interface_state')
|
||||||
|
output_params = gradio('textbox', 'html')
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
|
lambda x: x, gradio('textbox'), gradio('last_input')).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
lambda x: x, gradio('textbox'), gradio('last_input')).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['Undo'].click(lambda x: x, gradio('last_input'), gradio('textbox'), show_progress=False)
|
||||||
|
shared.gradio['markdown_render'].click(lambda x: x, gradio('textbox'), gradio('markdown'), queue=False)
|
||||||
|
gen_events.append(shared.gradio['Regenerate'].click(
|
||||||
|
lambda x: x, gradio('last_input'), gradio('textbox'), show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: None, None, None, _js=f"() => {{{ui.audio_notification_js}}}")
|
||||||
|
# lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
|
shared.gradio['prompt_menu'].change(load_prompt, gradio('prompt_menu'), gradio('textbox'), show_progress=False)
|
||||||
|
shared.gradio['save_prompt'].click(
|
||||||
|
lambda x: x, gradio('textbox'), gradio('save_contents')).then(
|
||||||
|
lambda: 'prompts/', None, gradio('save_root')).then(
|
||||||
|
lambda: utils.current_time() + '.txt', None, gradio('save_filename')).then(
|
||||||
|
lambda: gr.update(visible=True), None, gradio('file_saver'))
|
||||||
|
|
||||||
|
shared.gradio['delete_prompt'].click(
|
||||||
|
lambda: 'prompts/', None, gradio('delete_root')).then(
|
||||||
|
lambda x: x + '.txt', gradio('prompt_menu'), gradio('delete_filename')).then(
|
||||||
|
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||||
|
|
||||||
|
shared.gradio['count_tokens'].click(count_tokens, gradio('textbox'), gradio('status'), show_progress=False)
|
143
modules/ui_parameters.py
Normal file
143
modules/ui_parameters.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import loaders, presets, shared, ui, utils
|
||||||
|
from modules.utils import gradio
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui(default_preset):
|
||||||
|
generate_params = presets.load_preset(default_preset)
|
||||||
|
with gr.Tab("Parameters", elem_id="parameters"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['preset_menu'] = gr.Dropdown(choices=utils.get_available_presets(), value=default_preset, label='Generation parameters preset', elem_classes='slim-dropdown')
|
||||||
|
ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': utils.get_available_presets()}, 'refresh-button')
|
||||||
|
shared.gradio['save_preset'] = gr.Button('💾', elem_classes='refresh-button')
|
||||||
|
shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['filter_by_loader'] = gr.Dropdown(label="Filter by loader", choices=["All", "Transformers", "ExLlama_HF", "ExLlama", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp", "llamacpp_HF"], value="All", elem_classes='slim-dropdown')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Box():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
|
||||||
|
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
|
||||||
|
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
|
||||||
|
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
|
||||||
|
shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff')
|
||||||
|
shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff')
|
||||||
|
shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs')
|
||||||
|
shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=generate_params['top_a'], step=0.01, label='top_a')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
|
||||||
|
shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
|
||||||
|
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
|
||||||
|
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
|
||||||
|
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length')
|
||||||
|
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
||||||
|
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||||
|
|
||||||
|
with gr.Accordion("Learn more", open=False):
|
||||||
|
gr.Markdown("""
|
||||||
|
|
||||||
|
For a technical description of the parameters, the [transformers documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) is a good reference.
|
||||||
|
|
||||||
|
The best presets, according to the [Preset Arena](https://github.com/oobabooga/oobabooga.github.io/blob/main/arena/results.md) experiment, are:
|
||||||
|
|
||||||
|
* Instruction following:
|
||||||
|
1) Divine Intellect
|
||||||
|
2) Big O
|
||||||
|
3) simple-1
|
||||||
|
4) Space Alien
|
||||||
|
5) StarChat
|
||||||
|
6) Titanic
|
||||||
|
7) tfs-with-top-a
|
||||||
|
8) Asterism
|
||||||
|
9) Contrastive Search
|
||||||
|
|
||||||
|
* Chat:
|
||||||
|
1) Midnight Enigma
|
||||||
|
2) Yara
|
||||||
|
3) Shortwave
|
||||||
|
|
||||||
|
### Temperature
|
||||||
|
Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.
|
||||||
|
### top_p
|
||||||
|
If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.
|
||||||
|
### top_k
|
||||||
|
Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.
|
||||||
|
### typical_p
|
||||||
|
If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
|
||||||
|
### epsilon_cutoff
|
||||||
|
In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled. Should be used with top_p, top_k, and eta_cutoff set to 0.
|
||||||
|
### eta_cutoff
|
||||||
|
In units of 1e-4; a reasonable value is 3. Should be used with top_p, top_k, and epsilon_cutoff set to 0.
|
||||||
|
### repetition_penalty
|
||||||
|
Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.
|
||||||
|
### repetition_penalty_range
|
||||||
|
The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
|
||||||
|
### encoder_repetition_penalty
|
||||||
|
Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.
|
||||||
|
### no_repeat_ngram_size
|
||||||
|
If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.
|
||||||
|
### min_length
|
||||||
|
Minimum generation length in tokens.
|
||||||
|
### penalty_alpha
|
||||||
|
Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4.
|
||||||
|
|
||||||
|
""", elem_classes="markdown")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
create_chat_settings_menus()
|
||||||
|
with gr.Box():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=generate_params['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.')
|
||||||
|
shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt')
|
||||||
|
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
|
||||||
|
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
||||||
|
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.')
|
||||||
|
|
||||||
|
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.')
|
||||||
|
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||||
|
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||||
|
|
||||||
|
with gr.Box():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
||||||
|
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
|
||||||
|
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
|
||||||
|
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||||
|
|
||||||
|
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
|
||||||
|
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
|
||||||
|
|
||||||
|
|
||||||
|
def create_event_handlers():
|
||||||
|
shared.gradio['filter_by_loader'].change(loaders.blacklist_samplers, gradio('filter_by_loader'), gradio(loaders.list_all_samplers()), show_progress=False)
|
||||||
|
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()))
|
||||||
|
|
||||||
|
|
||||||
|
def create_chat_settings_menus():
|
||||||
|
if not shared.is_chat():
|
||||||
|
return
|
||||||
|
|
||||||
|
with gr.Box():
|
||||||
|
gr.Markdown("Chat parameters")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||||
|
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)', info='New generations will be called until either this number is reached or no new content is generated between two iterations.')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
|
71
modules/ui_session.py
Normal file
71
modules/ui_session.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import shared, ui, utils
|
||||||
|
from modules.github import clone_or_pull_repository
|
||||||
|
from modules.utils import gradio
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui():
|
||||||
|
with gr.Tab("Session", elem_id="session-tab"):
|
||||||
|
modes = ["default", "notebook", "chat"]
|
||||||
|
current_mode = "default"
|
||||||
|
for mode in modes[1:]:
|
||||||
|
if getattr(shared.args, mode):
|
||||||
|
current_mode = mode
|
||||||
|
break
|
||||||
|
|
||||||
|
cmd_list = vars(shared.args)
|
||||||
|
bool_list = sorted([k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes + ui.list_model_elements()])
|
||||||
|
bool_active = [k for k in bool_list if vars(shared.args)[k]]
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode", elem_classes='slim-dropdown')
|
||||||
|
shared.gradio['reset_interface'] = gr.Button("Apply and restart", elem_classes="small-button", variant="primary")
|
||||||
|
shared.gradio['toggle_dark_mode'] = gr.Button('Toggle 💡', elem_classes="small-button")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=utils.get_available_extensions(), value=shared.args.extensions, label="Available extensions", info='Note that some of these extensions may require manually installing Python requirements through the command: pip install -r extensions/extension_name/requirements.txt', elem_classes='checkboxgroup-table')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags", elem_classes='checkboxgroup-table')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
if not shared.args.multi_user:
|
||||||
|
shared.gradio['save_session'] = gr.Button('Save session', elem_id="save_session")
|
||||||
|
shared.gradio['load_session'] = gr.File(type='binary', file_types=['.json'], label="Upload Session JSON")
|
||||||
|
|
||||||
|
extension_name = gr.Textbox(lines=1, label='Install or update an extension', info='Enter the GitHub URL below and press Enter. For a list of extensions, see: https://github.com/oobabooga/text-generation-webui-extensions ⚠️ WARNING ⚠️ : extensions can execute arbitrary code. Make sure to inspect their source code before activating them.')
|
||||||
|
extension_status = gr.Markdown()
|
||||||
|
|
||||||
|
extension_name.submit(
|
||||||
|
clone_or_pull_repository, extension_name, extension_status, show_progress=False).then(
|
||||||
|
lambda: gr.update(choices=utils.get_available_extensions(), value=shared.args.extensions), None, gradio('extensions_menu'))
|
||||||
|
|
||||||
|
# Reset interface event
|
||||||
|
shared.gradio['reset_interface'].click(
|
||||||
|
set_interface_arguments, gradio('interface_modes_menu', 'extensions_menu', 'bool_menu'), None).then(
|
||||||
|
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;padding-top:20%;margin:0;height:100vh;color:lightgray;text-align:center;background:var(--body-background-fill)">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||||
|
|
||||||
|
shared.gradio['toggle_dark_mode'].click(lambda: None, None, None, _js='() => {document.getElementsByTagName("body")[0].classList.toggle("dark")}')
|
||||||
|
|
||||||
|
|
||||||
|
def set_interface_arguments(interface_mode, extensions, bool_active):
|
||||||
|
modes = ["default", "notebook", "chat", "cai_chat"]
|
||||||
|
cmd_list = vars(shared.args)
|
||||||
|
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
||||||
|
|
||||||
|
shared.args.extensions = extensions
|
||||||
|
for k in modes[1:]:
|
||||||
|
setattr(shared.args, k, False)
|
||||||
|
if interface_mode != "default":
|
||||||
|
setattr(shared.args, interface_mode, True)
|
||||||
|
for k in bool_list:
|
||||||
|
setattr(shared.args, k, False)
|
||||||
|
for k in bool_active:
|
||||||
|
setattr(shared.args, k, True)
|
||||||
|
|
||||||
|
shared.need_restart = True
|
Loading…
Reference in New Issue
Block a user