Merge pull request #5011 from oobabooga/dev

Merge dev branch
This commit is contained in:
oobabooga 2023-12-20 02:38:25 -03:00 committed by GitHub
commit c1f78dbd0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 120 additions and 171 deletions

View File

@ -6,27 +6,13 @@ import time
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer
from modules import chat, shared, ui_chat from modules import chat, shared, ui_chat
from modules.logging_colors import logger
from modules.ui import create_refresh_button from modules.ui import create_refresh_button
from modules.utils import gradio from modules.utils import gradio
try:
from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer
except ModuleNotFoundError:
logger.error(
"Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension."
"\n"
"\nLinux / Mac:\npip install -r extensions/coqui_tts/requirements.txt\n"
"\nWindows:\npip install -r extensions\\coqui_tts\\requirements.txt\n"
"\n"
"If you used the one-click installer, paste the command above in the terminal window launched after running the \"cmd_\" script. On Windows, that's \"cmd_windows.bat\"."
)
raise
os.environ["COQUI_TOS_AGREED"] = "1" os.environ["COQUI_TOS_AGREED"] = "1"
params = { params = {

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import json import json
import logging
import os import os
import traceback import traceback
from threading import Thread from threading import Thread
@ -367,6 +368,7 @@ def run_server():
if shared.args.admin_key and shared.args.admin_key != shared.args.api_key: if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n') logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')
logging.getLogger("uvicorn.error").propagate = False
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile) uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)

View File

@ -126,7 +126,7 @@ def load_quantized(model_name):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}') path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
pt_path = find_quantized_model_file(model_name) pt_path = find_quantized_model_file(model_name)
if not pt_path: if not pt_path:
logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...") logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
exit() exit()
else: else:
logger.info(f"Found the following quantized model: {pt_path}") logger.info(f"Found the following quantized model: {pt_path}")

View File

@ -138,7 +138,7 @@ def add_lora_transformers(lora_names):
# Add a LoRA when another LoRA is already present # Add a LoRA when another LoRA is already present
if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys(): if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
logger.info(f"Adding the LoRA(s) named {added_set} to the model...") logger.info(f"Adding the LoRA(s) named {added_set} to the model")
for lora in added_set: for lora in added_set:
shared.model.load_adapter(get_lora_path(lora), lora) shared.model.load_adapter(get_lora_path(lora), lora)

View File

@ -95,7 +95,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
else: else:
renderer = chat_renderer renderer = chat_renderer
if state['context'].strip() != '': if state['context'].strip() != '':
messages.append({"role": "system", "content": state['context']}) context = replace_character_names(state['context'], state['name1'], state['name2'])
messages.append({"role": "system", "content": context})
insert_pos = len(messages) insert_pos = len(messages)
for user_msg, assistant_msg in reversed(history): for user_msg, assistant_msg in reversed(history):

View File

@ -31,9 +31,14 @@ def load_extensions():
for i, name in enumerate(shared.args.extensions): for i, name in enumerate(shared.args.extensions):
if name in available_extensions: if name in available_extensions:
if name != 'api': if name != 'api':
logger.info(f'Loading the extension "{name}"...') logger.info(f'Loading the extension "{name}"')
try: try:
exec(f"import extensions.{name}.script") try:
exec(f"import extensions.{name}.script")
except ModuleNotFoundError:
logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n\nIf you used the one-click installer, paste the command above in the terminal window opened after launching the cmd script for your OS.")
raise
extension = getattr(extensions, name).script extension = getattr(extensions, name).script
apply_settings(extension, name) apply_settings(extension, name)
if extension not in setup_called and hasattr(extension, "setup"): if extension not in setup_called and hasattr(extension, "setup"):

View File

@ -1,117 +1,67 @@
# Copied from https://stackoverflow.com/a/1336640
import logging import logging
import platform
logging.basicConfig(
format='%(asctime)s %(levelname)s:%(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
def add_coloring_to_emit_windows(fn):
# add methods we need to the class
def _out_handle(self):
import ctypes
return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
out_handle = property(_out_handle)
def _set_color(self, code):
import ctypes
# Constants from the Windows API
self.STD_OUTPUT_HANDLE = -11
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
setattr(logging.StreamHandler, '_set_color', _set_color)
def new(*args):
FOREGROUND_BLUE = 0x0001 # text color contains blue.
FOREGROUND_GREEN = 0x0002 # text color contains green.
FOREGROUND_RED = 0x0004 # text color contains red.
FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
# winbase.h
# STD_INPUT_HANDLE = -10
# STD_OUTPUT_HANDLE = -11
# STD_ERROR_HANDLE = -12
# wincon.h
# FOREGROUND_BLACK = 0x0000
FOREGROUND_BLUE = 0x0001
FOREGROUND_GREEN = 0x0002
# FOREGROUND_CYAN = 0x0003
FOREGROUND_RED = 0x0004
FOREGROUND_MAGENTA = 0x0005
FOREGROUND_YELLOW = 0x0006
# FOREGROUND_GREY = 0x0007
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
# BACKGROUND_BLACK = 0x0000
# BACKGROUND_BLUE = 0x0010
# BACKGROUND_GREEN = 0x0020
# BACKGROUND_CYAN = 0x0030
# BACKGROUND_RED = 0x0040
# BACKGROUND_MAGENTA = 0x0050
BACKGROUND_YELLOW = 0x0060
# BACKGROUND_GREY = 0x0070
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.
levelno = args[1].levelno
if (levelno >= 50):
color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
elif (levelno >= 40):
color = FOREGROUND_RED | FOREGROUND_INTENSITY
elif (levelno >= 30):
color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
elif (levelno >= 20):
color = FOREGROUND_GREEN
elif (levelno >= 10):
color = FOREGROUND_MAGENTA
else:
color = FOREGROUND_WHITE
args[0]._set_color(color)
ret = fn(*args)
args[0]._set_color(FOREGROUND_WHITE)
# print "after"
return ret
return new
def add_coloring_to_emit_ansi(fn):
# add methods we need to the class
def new(*args):
levelno = args[1].levelno
if (levelno >= 50):
color = '\x1b[31m' # red
elif (levelno >= 40):
color = '\x1b[31m' # red
elif (levelno >= 30):
color = '\x1b[33m' # yellow
elif (levelno >= 20):
color = '\x1b[32m' # green
elif (levelno >= 10):
color = '\x1b[35m' # pink
else:
color = '\x1b[0m' # normal
args[1].msg = color + args[1].msg + '\x1b[0m' # normal
# print "after"
return fn(*args)
return new
if platform.system() == 'Windows':
# Windows does not support ANSI escapes and we are using API calls to set the console color
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
else:
# all non-Windows platforms are supporting ANSI escapes so we use them
logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
# log = logging.getLogger()
# log.addFilter(log_filter())
# //hdlr = logging.StreamHandler()
# //hdlr.setFormatter(formatter())
logger = logging.getLogger('text-generation-webui') logger = logging.getLogger('text-generation-webui')
logger.setLevel(logging.DEBUG)
def setup_logging():
'''
Copied from: https://github.com/vladmandic/automatic
All credits to vladmandic.
'''
class RingBuffer(logging.StreamHandler):
def __init__(self, capacity):
super().__init__()
self.capacity = capacity
self.buffer = []
self.formatter = logging.Formatter('{ "asctime":"%(asctime)s", "created":%(created)f, "facility":"%(name)s", "pid":%(process)d, "tid":%(thread)d, "level":"%(levelname)s", "module":"%(module)s", "func":"%(funcName)s", "msg":"%(message)s" }')
def emit(self, record):
msg = self.format(record)
# self.buffer.append(json.loads(msg))
self.buffer.append(msg)
if len(self.buffer) > self.capacity:
self.buffer.pop(0)
def get(self):
return self.buffer
from rich.console import Console
from rich.logging import RichHandler
from rich.pretty import install as pretty_install
from rich.theme import Theme
from rich.traceback import install as traceback_install
level = logging.DEBUG
logger.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
"traceback.border": "black",
"traceback.border.syntax_error": "black",
"inspect.value.border": "black",
}))
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
pretty_install(console=console)
traceback_install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
while logger.hasHandlers() and len(logger.handlers) > 0:
logger.removeHandler(logger.handlers[0])
# handlers
rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console)
rh.setLevel(level)
logger.addHandler(rh)
rb = RingBuffer(100) # 100 entries default in log ring buffer
rb.setLevel(level)
logger.addHandler(rb)
logger.buffer = rb.buffer
# overrides
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("lycoris").handlers = logger.handlers
setup_logging()

View File

@ -54,7 +54,7 @@ sampler_hijack.hijack_samplers()
def load_model(model_name, loader=None): def load_model(model_name, loader=None):
logger.info(f"Loading {model_name}...") logger.info(f"Loading {model_name}")
t0 = time.time() t0 = time.time()
shared.is_seq2seq = False shared.is_seq2seq = False
@ -413,8 +413,12 @@ def ExLlamav2_HF_loader(model_name):
def HQQ_loader(model_name): def HQQ_loader(model_name):
from hqq.engine.hf import HQQModelForCausalLM try:
from hqq.core.quantize import HQQLinear, HQQBackend from hqq.core.quantize import HQQBackend, HQQLinear
from hqq.engine.hf import HQQModelForCausalLM
except ModuleNotFoundError:
logger.error("HQQ is not installed. You can install it with:\n\npip install hqq")
return None
logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}") logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}")

View File

@ -204,22 +204,26 @@ for arg in sys.argv[1:]:
if hasattr(args, arg): if hasattr(args, arg):
provided_arguments.append(arg) provided_arguments.append(arg)
# Deprecation warnings
deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast'] deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
for k in deprecated_args:
if getattr(args, k):
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
# Security warnings
if args.trust_remote_code: def do_cmd_flags_warnings():
logger.warning('trust_remote_code is enabled. This is dangerous.')
if 'COLAB_GPU' not in os.environ and not args.nowebui: # Deprecation warnings
if args.share: for k in deprecated_args:
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.") if getattr(args, k):
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)): logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
if args.multi_user: # Security warnings
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.') if args.trust_remote_code:
logger.warning('trust_remote_code is enabled. This is dangerous.')
if 'COLAB_GPU' not in os.environ and not args.nowebui:
if args.share:
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
if args.multi_user:
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
def fix_loader_name(name): def fix_loader_name(name):

View File

@ -249,7 +249,7 @@ def backup_adapter(input_folder):
adapter_file = Path(f"{input_folder}/adapter_model.bin") adapter_file = Path(f"{input_folder}/adapter_model.bin")
if adapter_file.is_file(): if adapter_file.is_file():
logger.info("Backing up existing LoRA adapter...") logger.info("Backing up existing LoRA adapter")
creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime) creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d") creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
@ -406,7 +406,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
# == Prep the dataset, format, etc == # == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']: if raw_text_file not in ['None', '']:
train_template["template_type"] = "raw_text" train_template["template_type"] = "raw_text"
logger.info("Loading raw text file dataset...") logger.info("Loading raw text file dataset")
fullpath = clean_path('training/datasets', f'{raw_text_file}') fullpath = clean_path('training/datasets', f'{raw_text_file}')
fullpath = Path(fullpath) fullpath = Path(fullpath)
if fullpath.is_dir(): if fullpath.is_dir():
@ -486,7 +486,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
prompt = generate_prompt(data_point) prompt = generate_prompt(data_point)
return tokenize(prompt, add_eos_token) return tokenize(prompt, add_eos_token)
logger.info("Loading JSON datasets...") logger.info("Loading JSON datasets")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json')) data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30)) train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
@ -516,13 +516,13 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
# == Start prepping the model itself == # == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
logger.info("Getting model ready...") logger.info("Getting model ready")
prepare_model_for_kbit_training(shared.model) prepare_model_for_kbit_training(shared.model)
# base model is now frozen and should not be reused for any other LoRA training than this one # base model is now frozen and should not be reused for any other LoRA training than this one
shared.model_dirty_from_training = True shared.model_dirty_from_training = True
logger.info("Preparing for training...") logger.info("Preparing for training")
config = LoraConfig( config = LoraConfig(
r=lora_rank, r=lora_rank,
lora_alpha=lora_alpha, lora_alpha=lora_alpha,
@ -540,10 +540,10 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model) model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
try: try:
logger.info("Creating LoRA model...") logger.info("Creating LoRA model")
lora_model = get_peft_model(shared.model, config) lora_model = get_peft_model(shared.model, config)
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file(): if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
logger.info("Loading existing LoRA data...") logger.info("Loading existing LoRA data")
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True) state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
set_peft_model_state_dict(lora_model, state_dict_peft) set_peft_model_state_dict(lora_model, state_dict_peft)
except: except:
@ -648,7 +648,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
json.dump(train_template, file, indent=2) json.dump(train_template, file, indent=2)
# == Main run and monitor loop == # == Main run and monitor loop ==
logger.info("Starting training...") logger.info("Starting training")
yield "Starting..." yield "Starting..."
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model) lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
@ -730,7 +730,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
# Saving in the train thread might fail if an error occurs, so save here if so. # Saving in the train thread might fail if an error occurs, so save here if so.
if not tracked.did_save: if not tracked.did_save:
logger.info("Training complete, saving...") logger.info("Training complete, saving")
lora_model.save_pretrained(lora_file_path) lora_model.save_pretrained(lora_file_path)
if WANT_INTERRUPT: if WANT_INTERRUPT:

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64" exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64" exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64" exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11 exllamav2==0.0.11
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11 exllamav2==0.0.11
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11 exllamav2==0.0.11
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11 exllamav2==0.0.11
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64" exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11 exllamav2==0.0.11
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -12,6 +12,8 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict') warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')
with RequestBlocker(): with RequestBlocker():
import gradio as gr import gradio as gr
@ -54,6 +56,7 @@ from modules.models_settings import (
get_model_metadata, get_model_metadata,
update_model_parameters update_model_parameters
) )
from modules.shared import do_cmd_flags_warnings
from modules.utils import gradio from modules.utils import gradio
@ -170,6 +173,9 @@ def create_interface():
if __name__ == "__main__": if __name__ == "__main__":
logger.info("Starting Text generation web UI")
do_cmd_flags_warnings()
# Load custom settings # Load custom settings
settings_file = None settings_file = None
if shared.args.settings is not None and Path(shared.args.settings).exists(): if shared.args.settings is not None and Path(shared.args.settings).exists():
@ -180,7 +186,7 @@ if __name__ == "__main__":
settings_file = Path('settings.json') settings_file = Path('settings.json')
if settings_file is not None: if settings_file is not None:
logger.info(f"Loading settings from {settings_file}...") logger.info(f"Loading settings from {settings_file}")
file_contents = open(settings_file, 'r', encoding='utf-8').read() file_contents = open(settings_file, 'r', encoding='utf-8').read()
new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents) new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents)
shared.settings.update(new_settings) shared.settings.update(new_settings)