mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-01 00:39:09 +01:00
commit
c1f78dbd0f
@ -6,27 +6,13 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from TTS.api import TTS
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
|
||||
from modules import chat, shared, ui_chat
|
||||
from modules.logging_colors import logger
|
||||
from modules.ui import create_refresh_button
|
||||
from modules.utils import gradio
|
||||
|
||||
try:
|
||||
from TTS.api import TTS
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
except ModuleNotFoundError:
|
||||
logger.error(
|
||||
"Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension."
|
||||
"\n"
|
||||
"\nLinux / Mac:\npip install -r extensions/coqui_tts/requirements.txt\n"
|
||||
"\nWindows:\npip install -r extensions\\coqui_tts\\requirements.txt\n"
|
||||
"\n"
|
||||
"If you used the one-click installer, paste the command above in the terminal window launched after running the \"cmd_\" script. On Windows, that's \"cmd_windows.bat\"."
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
os.environ["COQUI_TOS_AGREED"] = "1"
|
||||
|
||||
params = {
|
||||
|
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
from threading import Thread
|
||||
@ -367,6 +368,7 @@ def run_server():
|
||||
if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
|
||||
logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')
|
||||
|
||||
logging.getLogger("uvicorn.error").propagate = False
|
||||
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
|
||||
|
||||
|
||||
|
@ -126,7 +126,7 @@ def load_quantized(model_name):
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
pt_path = find_quantized_model_file(model_name)
|
||||
if not pt_path:
|
||||
logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
||||
logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
|
||||
exit()
|
||||
else:
|
||||
logger.info(f"Found the following quantized model: {pt_path}")
|
||||
|
@ -138,7 +138,7 @@ def add_lora_transformers(lora_names):
|
||||
|
||||
# Add a LoRA when another LoRA is already present
|
||||
if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
|
||||
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
||||
logger.info(f"Adding the LoRA(s) named {added_set} to the model")
|
||||
for lora in added_set:
|
||||
shared.model.load_adapter(get_lora_path(lora), lora)
|
||||
|
||||
|
@ -95,7 +95,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
else:
|
||||
renderer = chat_renderer
|
||||
if state['context'].strip() != '':
|
||||
messages.append({"role": "system", "content": state['context']})
|
||||
context = replace_character_names(state['context'], state['name1'], state['name2'])
|
||||
messages.append({"role": "system", "content": context})
|
||||
|
||||
insert_pos = len(messages)
|
||||
for user_msg, assistant_msg in reversed(history):
|
||||
|
@ -31,9 +31,14 @@ def load_extensions():
|
||||
for i, name in enumerate(shared.args.extensions):
|
||||
if name in available_extensions:
|
||||
if name != 'api':
|
||||
logger.info(f'Loading the extension "{name}"...')
|
||||
logger.info(f'Loading the extension "{name}"')
|
||||
try:
|
||||
exec(f"import extensions.{name}.script")
|
||||
try:
|
||||
exec(f"import extensions.{name}.script")
|
||||
except ModuleNotFoundError:
|
||||
logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n\nIf you used the one-click installer, paste the command above in the terminal window opened after launching the cmd script for your OS.")
|
||||
raise
|
||||
|
||||
extension = getattr(extensions, name).script
|
||||
apply_settings(extension, name)
|
||||
if extension not in setup_called and hasattr(extension, "setup"):
|
||||
|
@ -1,117 +1,67 @@
|
||||
# Copied from https://stackoverflow.com/a/1336640
|
||||
|
||||
import logging
|
||||
import platform
|
||||
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s %(levelname)s:%(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
)
|
||||
|
||||
|
||||
def add_coloring_to_emit_windows(fn):
|
||||
# add methods we need to the class
|
||||
def _out_handle(self):
|
||||
import ctypes
|
||||
return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
|
||||
out_handle = property(_out_handle)
|
||||
|
||||
def _set_color(self, code):
|
||||
import ctypes
|
||||
|
||||
# Constants from the Windows API
|
||||
self.STD_OUTPUT_HANDLE = -11
|
||||
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
|
||||
ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
|
||||
|
||||
setattr(logging.StreamHandler, '_set_color', _set_color)
|
||||
|
||||
def new(*args):
|
||||
FOREGROUND_BLUE = 0x0001 # text color contains blue.
|
||||
FOREGROUND_GREEN = 0x0002 # text color contains green.
|
||||
FOREGROUND_RED = 0x0004 # text color contains red.
|
||||
FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
|
||||
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
|
||||
# winbase.h
|
||||
# STD_INPUT_HANDLE = -10
|
||||
# STD_OUTPUT_HANDLE = -11
|
||||
# STD_ERROR_HANDLE = -12
|
||||
|
||||
# wincon.h
|
||||
# FOREGROUND_BLACK = 0x0000
|
||||
FOREGROUND_BLUE = 0x0001
|
||||
FOREGROUND_GREEN = 0x0002
|
||||
# FOREGROUND_CYAN = 0x0003
|
||||
FOREGROUND_RED = 0x0004
|
||||
FOREGROUND_MAGENTA = 0x0005
|
||||
FOREGROUND_YELLOW = 0x0006
|
||||
# FOREGROUND_GREY = 0x0007
|
||||
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
|
||||
|
||||
# BACKGROUND_BLACK = 0x0000
|
||||
# BACKGROUND_BLUE = 0x0010
|
||||
# BACKGROUND_GREEN = 0x0020
|
||||
# BACKGROUND_CYAN = 0x0030
|
||||
# BACKGROUND_RED = 0x0040
|
||||
# BACKGROUND_MAGENTA = 0x0050
|
||||
BACKGROUND_YELLOW = 0x0060
|
||||
# BACKGROUND_GREY = 0x0070
|
||||
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.
|
||||
|
||||
levelno = args[1].levelno
|
||||
if (levelno >= 50):
|
||||
color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
|
||||
elif (levelno >= 40):
|
||||
color = FOREGROUND_RED | FOREGROUND_INTENSITY
|
||||
elif (levelno >= 30):
|
||||
color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
|
||||
elif (levelno >= 20):
|
||||
color = FOREGROUND_GREEN
|
||||
elif (levelno >= 10):
|
||||
color = FOREGROUND_MAGENTA
|
||||
else:
|
||||
color = FOREGROUND_WHITE
|
||||
args[0]._set_color(color)
|
||||
|
||||
ret = fn(*args)
|
||||
args[0]._set_color(FOREGROUND_WHITE)
|
||||
# print "after"
|
||||
return ret
|
||||
return new
|
||||
|
||||
|
||||
def add_coloring_to_emit_ansi(fn):
|
||||
# add methods we need to the class
|
||||
def new(*args):
|
||||
levelno = args[1].levelno
|
||||
if (levelno >= 50):
|
||||
color = '\x1b[31m' # red
|
||||
elif (levelno >= 40):
|
||||
color = '\x1b[31m' # red
|
||||
elif (levelno >= 30):
|
||||
color = '\x1b[33m' # yellow
|
||||
elif (levelno >= 20):
|
||||
color = '\x1b[32m' # green
|
||||
elif (levelno >= 10):
|
||||
color = '\x1b[35m' # pink
|
||||
else:
|
||||
color = '\x1b[0m' # normal
|
||||
args[1].msg = color + args[1].msg + '\x1b[0m' # normal
|
||||
# print "after"
|
||||
return fn(*args)
|
||||
return new
|
||||
|
||||
|
||||
if platform.system() == 'Windows':
|
||||
# Windows does not support ANSI escapes and we are using API calls to set the console color
|
||||
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
|
||||
else:
|
||||
# all non-Windows platforms are supporting ANSI escapes so we use them
|
||||
logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
|
||||
# log = logging.getLogger()
|
||||
# log.addFilter(log_filter())
|
||||
# //hdlr = logging.StreamHandler()
|
||||
# //hdlr.setFormatter(formatter())
|
||||
|
||||
logger = logging.getLogger('text-generation-webui')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def setup_logging():
|
||||
'''
|
||||
Copied from: https://github.com/vladmandic/automatic
|
||||
|
||||
All credits to vladmandic.
|
||||
'''
|
||||
|
||||
class RingBuffer(logging.StreamHandler):
|
||||
def __init__(self, capacity):
|
||||
super().__init__()
|
||||
self.capacity = capacity
|
||||
self.buffer = []
|
||||
self.formatter = logging.Formatter('{ "asctime":"%(asctime)s", "created":%(created)f, "facility":"%(name)s", "pid":%(process)d, "tid":%(thread)d, "level":"%(levelname)s", "module":"%(module)s", "func":"%(funcName)s", "msg":"%(message)s" }')
|
||||
|
||||
def emit(self, record):
|
||||
msg = self.format(record)
|
||||
# self.buffer.append(json.loads(msg))
|
||||
self.buffer.append(msg)
|
||||
if len(self.buffer) > self.capacity:
|
||||
self.buffer.pop(0)
|
||||
|
||||
def get(self):
|
||||
return self.buffer
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from rich.pretty import install as pretty_install
|
||||
from rich.theme import Theme
|
||||
from rich.traceback import install as traceback_install
|
||||
|
||||
level = logging.DEBUG
|
||||
logger.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
|
||||
console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
|
||||
"traceback.border": "black",
|
||||
"traceback.border.syntax_error": "black",
|
||||
"inspect.value.border": "black",
|
||||
}))
|
||||
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
|
||||
pretty_install(console=console)
|
||||
traceback_install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
|
||||
while logger.hasHandlers() and len(logger.handlers) > 0:
|
||||
logger.removeHandler(logger.handlers[0])
|
||||
|
||||
# handlers
|
||||
rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console)
|
||||
rh.setLevel(level)
|
||||
logger.addHandler(rh)
|
||||
|
||||
rb = RingBuffer(100) # 100 entries default in log ring buffer
|
||||
rb.setLevel(level)
|
||||
logger.addHandler(rb)
|
||||
logger.buffer = rb.buffer
|
||||
|
||||
# overrides
|
||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
logging.getLogger("torch").setLevel(logging.ERROR)
|
||||
logging.getLogger("lycoris").handlers = logger.handlers
|
||||
|
||||
|
||||
setup_logging()
|
||||
|
@ -54,7 +54,7 @@ sampler_hijack.hijack_samplers()
|
||||
|
||||
|
||||
def load_model(model_name, loader=None):
|
||||
logger.info(f"Loading {model_name}...")
|
||||
logger.info(f"Loading {model_name}")
|
||||
t0 = time.time()
|
||||
|
||||
shared.is_seq2seq = False
|
||||
@ -413,8 +413,12 @@ def ExLlamav2_HF_loader(model_name):
|
||||
|
||||
|
||||
def HQQ_loader(model_name):
|
||||
from hqq.engine.hf import HQQModelForCausalLM
|
||||
from hqq.core.quantize import HQQLinear, HQQBackend
|
||||
try:
|
||||
from hqq.core.quantize import HQQBackend, HQQLinear
|
||||
from hqq.engine.hf import HQQModelForCausalLM
|
||||
except ModuleNotFoundError:
|
||||
logger.error("HQQ is not installed. You can install it with:\n\npip install hqq")
|
||||
return None
|
||||
|
||||
logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}")
|
||||
|
||||
|
@ -204,22 +204,26 @@ for arg in sys.argv[1:]:
|
||||
if hasattr(args, arg):
|
||||
provided_arguments.append(arg)
|
||||
|
||||
# Deprecation warnings
|
||||
deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
|
||||
for k in deprecated_args:
|
||||
if getattr(args, k):
|
||||
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
|
||||
|
||||
# Security warnings
|
||||
if args.trust_remote_code:
|
||||
logger.warning('trust_remote_code is enabled. This is dangerous.')
|
||||
if 'COLAB_GPU' not in os.environ and not args.nowebui:
|
||||
if args.share:
|
||||
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
|
||||
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
|
||||
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
|
||||
if args.multi_user:
|
||||
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
|
||||
|
||||
def do_cmd_flags_warnings():
|
||||
|
||||
# Deprecation warnings
|
||||
for k in deprecated_args:
|
||||
if getattr(args, k):
|
||||
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
|
||||
|
||||
# Security warnings
|
||||
if args.trust_remote_code:
|
||||
logger.warning('trust_remote_code is enabled. This is dangerous.')
|
||||
if 'COLAB_GPU' not in os.environ and not args.nowebui:
|
||||
if args.share:
|
||||
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
|
||||
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
|
||||
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
|
||||
if args.multi_user:
|
||||
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
|
||||
|
||||
|
||||
def fix_loader_name(name):
|
||||
|
@ -249,7 +249,7 @@ def backup_adapter(input_folder):
|
||||
adapter_file = Path(f"{input_folder}/adapter_model.bin")
|
||||
if adapter_file.is_file():
|
||||
|
||||
logger.info("Backing up existing LoRA adapter...")
|
||||
logger.info("Backing up existing LoRA adapter")
|
||||
creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
|
||||
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
|
||||
|
||||
@ -406,7 +406,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||
# == Prep the dataset, format, etc ==
|
||||
if raw_text_file not in ['None', '']:
|
||||
train_template["template_type"] = "raw_text"
|
||||
logger.info("Loading raw text file dataset...")
|
||||
logger.info("Loading raw text file dataset")
|
||||
fullpath = clean_path('training/datasets', f'{raw_text_file}')
|
||||
fullpath = Path(fullpath)
|
||||
if fullpath.is_dir():
|
||||
@ -486,7 +486,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||
prompt = generate_prompt(data_point)
|
||||
return tokenize(prompt, add_eos_token)
|
||||
|
||||
logger.info("Loading JSON datasets...")
|
||||
logger.info("Loading JSON datasets")
|
||||
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
||||
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||
|
||||
@ -516,13 +516,13 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||
|
||||
# == Start prepping the model itself ==
|
||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||
logger.info("Getting model ready...")
|
||||
logger.info("Getting model ready")
|
||||
prepare_model_for_kbit_training(shared.model)
|
||||
|
||||
# base model is now frozen and should not be reused for any other LoRA training than this one
|
||||
shared.model_dirty_from_training = True
|
||||
|
||||
logger.info("Preparing for training...")
|
||||
logger.info("Preparing for training")
|
||||
config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
@ -540,10 +540,10 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
|
||||
|
||||
try:
|
||||
logger.info("Creating LoRA model...")
|
||||
logger.info("Creating LoRA model")
|
||||
lora_model = get_peft_model(shared.model, config)
|
||||
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
||||
logger.info("Loading existing LoRA data...")
|
||||
logger.info("Loading existing LoRA data")
|
||||
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
|
||||
set_peft_model_state_dict(lora_model, state_dict_peft)
|
||||
except:
|
||||
@ -648,7 +648,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||
json.dump(train_template, file, indent=2)
|
||||
|
||||
# == Main run and monitor loop ==
|
||||
logger.info("Starting training...")
|
||||
logger.info("Starting training")
|
||||
yield "Starting..."
|
||||
|
||||
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
||||
@ -730,7 +730,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||
|
||||
# Saving in the train thread might fail if an error occurs, so save here if so.
|
||||
if not tracked.did_save:
|
||||
logger.info("Training complete, saving...")
|
||||
logger.info("Training complete, saving")
|
||||
lora_model.save_pretrained(lora_file_path)
|
||||
|
||||
if WANT_INTERRUPT:
|
||||
|
@ -4,7 +4,6 @@ datasets
|
||||
einops
|
||||
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
|
||||
gradio==3.50.*
|
||||
hqq==0.1.1
|
||||
markdown
|
||||
numpy==1.24.*
|
||||
optimum==1.16.*
|
||||
|
@ -4,7 +4,6 @@ datasets
|
||||
einops
|
||||
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
|
||||
gradio==3.50.*
|
||||
hqq==0.1.1
|
||||
markdown
|
||||
numpy==1.24.*
|
||||
optimum==1.16.*
|
||||
|
@ -4,7 +4,6 @@ datasets
|
||||
einops
|
||||
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
|
||||
gradio==3.50.*
|
||||
hqq==0.1.1
|
||||
markdown
|
||||
numpy==1.24.*
|
||||
optimum==1.16.*
|
||||
|
@ -4,7 +4,6 @@ datasets
|
||||
einops
|
||||
exllamav2==0.0.11
|
||||
gradio==3.50.*
|
||||
hqq==0.1.1
|
||||
markdown
|
||||
numpy==1.24.*
|
||||
optimum==1.16.*
|
||||
|
@ -4,7 +4,6 @@ datasets
|
||||
einops
|
||||
exllamav2==0.0.11
|
||||
gradio==3.50.*
|
||||
hqq==0.1.1
|
||||
markdown
|
||||
numpy==1.24.*
|
||||
optimum==1.16.*
|
||||
|
@ -4,7 +4,6 @@ datasets
|
||||
einops
|
||||
exllamav2==0.0.11
|
||||
gradio==3.50.*
|
||||
hqq==0.1.1
|
||||
markdown
|
||||
numpy==1.24.*
|
||||
optimum==1.16.*
|
||||
|
@ -4,7 +4,6 @@ datasets
|
||||
einops
|
||||
exllamav2==0.0.11
|
||||
gradio==3.50.*
|
||||
hqq==0.1.1
|
||||
markdown
|
||||
numpy==1.24.*
|
||||
optimum==1.16.*
|
||||
|
@ -4,7 +4,6 @@ datasets
|
||||
einops
|
||||
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
|
||||
gradio==3.50.*
|
||||
hqq==0.1.1
|
||||
markdown
|
||||
numpy==1.24.*
|
||||
optimum==1.16.*
|
||||
|
@ -4,7 +4,6 @@ datasets
|
||||
einops
|
||||
exllamav2==0.0.11
|
||||
gradio==3.50.*
|
||||
hqq==0.1.1
|
||||
markdown
|
||||
numpy==1.24.*
|
||||
optimum==1.16.*
|
||||
|
@ -12,6 +12,8 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')
|
||||
|
||||
with RequestBlocker():
|
||||
import gradio as gr
|
||||
@ -54,6 +56,7 @@ from modules.models_settings import (
|
||||
get_model_metadata,
|
||||
update_model_parameters
|
||||
)
|
||||
from modules.shared import do_cmd_flags_warnings
|
||||
from modules.utils import gradio
|
||||
|
||||
|
||||
@ -170,6 +173,9 @@ def create_interface():
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
logger.info("Starting Text generation web UI")
|
||||
do_cmd_flags_warnings()
|
||||
|
||||
# Load custom settings
|
||||
settings_file = None
|
||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||
@ -180,7 +186,7 @@ if __name__ == "__main__":
|
||||
settings_file = Path('settings.json')
|
||||
|
||||
if settings_file is not None:
|
||||
logger.info(f"Loading settings from {settings_file}...")
|
||||
logger.info(f"Loading settings from {settings_file}")
|
||||
file_contents = open(settings_file, 'r', encoding='utf-8').read()
|
||||
new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents)
|
||||
shared.settings.update(new_settings)
|
||||
|
Loading…
Reference in New Issue
Block a user