mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-23 16:38:21 +01:00
commit
c1f78dbd0f
@ -6,27 +6,13 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
from TTS.api import TTS
|
||||||
|
from TTS.utils.synthesizer import Synthesizer
|
||||||
|
|
||||||
from modules import chat, shared, ui_chat
|
from modules import chat, shared, ui_chat
|
||||||
from modules.logging_colors import logger
|
|
||||||
from modules.ui import create_refresh_button
|
from modules.ui import create_refresh_button
|
||||||
from modules.utils import gradio
|
from modules.utils import gradio
|
||||||
|
|
||||||
try:
|
|
||||||
from TTS.api import TTS
|
|
||||||
from TTS.utils.synthesizer import Synthesizer
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
logger.error(
|
|
||||||
"Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension."
|
|
||||||
"\n"
|
|
||||||
"\nLinux / Mac:\npip install -r extensions/coqui_tts/requirements.txt\n"
|
|
||||||
"\nWindows:\npip install -r extensions\\coqui_tts\\requirements.txt\n"
|
|
||||||
"\n"
|
|
||||||
"If you used the one-click installer, paste the command above in the terminal window launched after running the \"cmd_\" script. On Windows, that's \"cmd_windows.bat\"."
|
|
||||||
)
|
|
||||||
|
|
||||||
raise
|
|
||||||
|
|
||||||
os.environ["COQUI_TOS_AGREED"] = "1"
|
os.environ["COQUI_TOS_AGREED"] = "1"
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@ -367,6 +368,7 @@ def run_server():
|
|||||||
if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
|
if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
|
||||||
logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')
|
logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')
|
||||||
|
|
||||||
|
logging.getLogger("uvicorn.error").propagate = False
|
||||||
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
|
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
|
||||||
|
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ def load_quantized(model_name):
|
|||||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||||
pt_path = find_quantized_model_file(model_name)
|
pt_path = find_quantized_model_file(model_name)
|
||||||
if not pt_path:
|
if not pt_path:
|
||||||
logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
|
||||||
exit()
|
exit()
|
||||||
else:
|
else:
|
||||||
logger.info(f"Found the following quantized model: {pt_path}")
|
logger.info(f"Found the following quantized model: {pt_path}")
|
||||||
|
@ -138,7 +138,7 @@ def add_lora_transformers(lora_names):
|
|||||||
|
|
||||||
# Add a LoRA when another LoRA is already present
|
# Add a LoRA when another LoRA is already present
|
||||||
if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
|
if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
|
||||||
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
logger.info(f"Adding the LoRA(s) named {added_set} to the model")
|
||||||
for lora in added_set:
|
for lora in added_set:
|
||||||
shared.model.load_adapter(get_lora_path(lora), lora)
|
shared.model.load_adapter(get_lora_path(lora), lora)
|
||||||
|
|
||||||
|
@ -95,7 +95,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
else:
|
else:
|
||||||
renderer = chat_renderer
|
renderer = chat_renderer
|
||||||
if state['context'].strip() != '':
|
if state['context'].strip() != '':
|
||||||
messages.append({"role": "system", "content": state['context']})
|
context = replace_character_names(state['context'], state['name1'], state['name2'])
|
||||||
|
messages.append({"role": "system", "content": context})
|
||||||
|
|
||||||
insert_pos = len(messages)
|
insert_pos = len(messages)
|
||||||
for user_msg, assistant_msg in reversed(history):
|
for user_msg, assistant_msg in reversed(history):
|
||||||
|
@ -31,9 +31,14 @@ def load_extensions():
|
|||||||
for i, name in enumerate(shared.args.extensions):
|
for i, name in enumerate(shared.args.extensions):
|
||||||
if name in available_extensions:
|
if name in available_extensions:
|
||||||
if name != 'api':
|
if name != 'api':
|
||||||
logger.info(f'Loading the extension "{name}"...')
|
logger.info(f'Loading the extension "{name}"')
|
||||||
|
try:
|
||||||
try:
|
try:
|
||||||
exec(f"import extensions.{name}.script")
|
exec(f"import extensions.{name}.script")
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n\nIf you used the one-click installer, paste the command above in the terminal window opened after launching the cmd script for your OS.")
|
||||||
|
raise
|
||||||
|
|
||||||
extension = getattr(extensions, name).script
|
extension = getattr(extensions, name).script
|
||||||
apply_settings(extension, name)
|
apply_settings(extension, name)
|
||||||
if extension not in setup_called and hasattr(extension, "setup"):
|
if extension not in setup_called and hasattr(extension, "setup"):
|
||||||
|
@ -1,117 +1,67 @@
|
|||||||
# Copied from https://stackoverflow.com/a/1336640
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import platform
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
format='%(asctime)s %(levelname)s:%(message)s',
|
|
||||||
datefmt='%Y-%m-%d %H:%M:%S',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_coloring_to_emit_windows(fn):
|
|
||||||
# add methods we need to the class
|
|
||||||
def _out_handle(self):
|
|
||||||
import ctypes
|
|
||||||
return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
|
|
||||||
out_handle = property(_out_handle)
|
|
||||||
|
|
||||||
def _set_color(self, code):
|
|
||||||
import ctypes
|
|
||||||
|
|
||||||
# Constants from the Windows API
|
|
||||||
self.STD_OUTPUT_HANDLE = -11
|
|
||||||
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
|
|
||||||
ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
|
|
||||||
|
|
||||||
setattr(logging.StreamHandler, '_set_color', _set_color)
|
|
||||||
|
|
||||||
def new(*args):
|
|
||||||
FOREGROUND_BLUE = 0x0001 # text color contains blue.
|
|
||||||
FOREGROUND_GREEN = 0x0002 # text color contains green.
|
|
||||||
FOREGROUND_RED = 0x0004 # text color contains red.
|
|
||||||
FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
|
|
||||||
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
|
|
||||||
# winbase.h
|
|
||||||
# STD_INPUT_HANDLE = -10
|
|
||||||
# STD_OUTPUT_HANDLE = -11
|
|
||||||
# STD_ERROR_HANDLE = -12
|
|
||||||
|
|
||||||
# wincon.h
|
|
||||||
# FOREGROUND_BLACK = 0x0000
|
|
||||||
FOREGROUND_BLUE = 0x0001
|
|
||||||
FOREGROUND_GREEN = 0x0002
|
|
||||||
# FOREGROUND_CYAN = 0x0003
|
|
||||||
FOREGROUND_RED = 0x0004
|
|
||||||
FOREGROUND_MAGENTA = 0x0005
|
|
||||||
FOREGROUND_YELLOW = 0x0006
|
|
||||||
# FOREGROUND_GREY = 0x0007
|
|
||||||
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
|
|
||||||
|
|
||||||
# BACKGROUND_BLACK = 0x0000
|
|
||||||
# BACKGROUND_BLUE = 0x0010
|
|
||||||
# BACKGROUND_GREEN = 0x0020
|
|
||||||
# BACKGROUND_CYAN = 0x0030
|
|
||||||
# BACKGROUND_RED = 0x0040
|
|
||||||
# BACKGROUND_MAGENTA = 0x0050
|
|
||||||
BACKGROUND_YELLOW = 0x0060
|
|
||||||
# BACKGROUND_GREY = 0x0070
|
|
||||||
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.
|
|
||||||
|
|
||||||
levelno = args[1].levelno
|
|
||||||
if (levelno >= 50):
|
|
||||||
color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
|
|
||||||
elif (levelno >= 40):
|
|
||||||
color = FOREGROUND_RED | FOREGROUND_INTENSITY
|
|
||||||
elif (levelno >= 30):
|
|
||||||
color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
|
|
||||||
elif (levelno >= 20):
|
|
||||||
color = FOREGROUND_GREEN
|
|
||||||
elif (levelno >= 10):
|
|
||||||
color = FOREGROUND_MAGENTA
|
|
||||||
else:
|
|
||||||
color = FOREGROUND_WHITE
|
|
||||||
args[0]._set_color(color)
|
|
||||||
|
|
||||||
ret = fn(*args)
|
|
||||||
args[0]._set_color(FOREGROUND_WHITE)
|
|
||||||
# print "after"
|
|
||||||
return ret
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
def add_coloring_to_emit_ansi(fn):
|
|
||||||
# add methods we need to the class
|
|
||||||
def new(*args):
|
|
||||||
levelno = args[1].levelno
|
|
||||||
if (levelno >= 50):
|
|
||||||
color = '\x1b[31m' # red
|
|
||||||
elif (levelno >= 40):
|
|
||||||
color = '\x1b[31m' # red
|
|
||||||
elif (levelno >= 30):
|
|
||||||
color = '\x1b[33m' # yellow
|
|
||||||
elif (levelno >= 20):
|
|
||||||
color = '\x1b[32m' # green
|
|
||||||
elif (levelno >= 10):
|
|
||||||
color = '\x1b[35m' # pink
|
|
||||||
else:
|
|
||||||
color = '\x1b[0m' # normal
|
|
||||||
args[1].msg = color + args[1].msg + '\x1b[0m' # normal
|
|
||||||
# print "after"
|
|
||||||
return fn(*args)
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
if platform.system() == 'Windows':
|
|
||||||
# Windows does not support ANSI escapes and we are using API calls to set the console color
|
|
||||||
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
|
|
||||||
else:
|
|
||||||
# all non-Windows platforms are supporting ANSI escapes so we use them
|
|
||||||
logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
|
|
||||||
# log = logging.getLogger()
|
|
||||||
# log.addFilter(log_filter())
|
|
||||||
# //hdlr = logging.StreamHandler()
|
|
||||||
# //hdlr.setFormatter(formatter())
|
|
||||||
|
|
||||||
logger = logging.getLogger('text-generation-webui')
|
logger = logging.getLogger('text-generation-webui')
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
|
def setup_logging():
|
||||||
|
'''
|
||||||
|
Copied from: https://github.com/vladmandic/automatic
|
||||||
|
|
||||||
|
All credits to vladmandic.
|
||||||
|
'''
|
||||||
|
|
||||||
|
class RingBuffer(logging.StreamHandler):
|
||||||
|
def __init__(self, capacity):
|
||||||
|
super().__init__()
|
||||||
|
self.capacity = capacity
|
||||||
|
self.buffer = []
|
||||||
|
self.formatter = logging.Formatter('{ "asctime":"%(asctime)s", "created":%(created)f, "facility":"%(name)s", "pid":%(process)d, "tid":%(thread)d, "level":"%(levelname)s", "module":"%(module)s", "func":"%(funcName)s", "msg":"%(message)s" }')
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
msg = self.format(record)
|
||||||
|
# self.buffer.append(json.loads(msg))
|
||||||
|
self.buffer.append(msg)
|
||||||
|
if len(self.buffer) > self.capacity:
|
||||||
|
self.buffer.pop(0)
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
return self.buffer
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
from rich.pretty import install as pretty_install
|
||||||
|
from rich.theme import Theme
|
||||||
|
from rich.traceback import install as traceback_install
|
||||||
|
|
||||||
|
level = logging.DEBUG
|
||||||
|
logger.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
|
||||||
|
console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
|
||||||
|
"traceback.border": "black",
|
||||||
|
"traceback.border.syntax_error": "black",
|
||||||
|
"inspect.value.border": "black",
|
||||||
|
}))
|
||||||
|
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
|
||||||
|
pretty_install(console=console)
|
||||||
|
traceback_install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
|
||||||
|
while logger.hasHandlers() and len(logger.handlers) > 0:
|
||||||
|
logger.removeHandler(logger.handlers[0])
|
||||||
|
|
||||||
|
# handlers
|
||||||
|
rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console)
|
||||||
|
rh.setLevel(level)
|
||||||
|
logger.addHandler(rh)
|
||||||
|
|
||||||
|
rb = RingBuffer(100) # 100 entries default in log ring buffer
|
||||||
|
rb.setLevel(level)
|
||||||
|
logger.addHandler(rb)
|
||||||
|
logger.buffer = rb.buffer
|
||||||
|
|
||||||
|
# overrides
|
||||||
|
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("torch").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("lycoris").handlers = logger.handlers
|
||||||
|
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
@ -54,7 +54,7 @@ sampler_hijack.hijack_samplers()
|
|||||||
|
|
||||||
|
|
||||||
def load_model(model_name, loader=None):
|
def load_model(model_name, loader=None):
|
||||||
logger.info(f"Loading {model_name}...")
|
logger.info(f"Loading {model_name}")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
shared.is_seq2seq = False
|
shared.is_seq2seq = False
|
||||||
@ -413,8 +413,12 @@ def ExLlamav2_HF_loader(model_name):
|
|||||||
|
|
||||||
|
|
||||||
def HQQ_loader(model_name):
|
def HQQ_loader(model_name):
|
||||||
|
try:
|
||||||
|
from hqq.core.quantize import HQQBackend, HQQLinear
|
||||||
from hqq.engine.hf import HQQModelForCausalLM
|
from hqq.engine.hf import HQQModelForCausalLM
|
||||||
from hqq.core.quantize import HQQLinear, HQQBackend
|
except ModuleNotFoundError:
|
||||||
|
logger.error("HQQ is not installed. You can install it with:\n\npip install hqq")
|
||||||
|
return None
|
||||||
|
|
||||||
logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}")
|
logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}")
|
||||||
|
|
||||||
|
@ -204,16 +204,20 @@ for arg in sys.argv[1:]:
|
|||||||
if hasattr(args, arg):
|
if hasattr(args, arg):
|
||||||
provided_arguments.append(arg)
|
provided_arguments.append(arg)
|
||||||
|
|
||||||
# Deprecation warnings
|
|
||||||
deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
|
deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
|
||||||
for k in deprecated_args:
|
|
||||||
|
|
||||||
|
def do_cmd_flags_warnings():
|
||||||
|
|
||||||
|
# Deprecation warnings
|
||||||
|
for k in deprecated_args:
|
||||||
if getattr(args, k):
|
if getattr(args, k):
|
||||||
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
|
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
|
||||||
|
|
||||||
# Security warnings
|
# Security warnings
|
||||||
if args.trust_remote_code:
|
if args.trust_remote_code:
|
||||||
logger.warning('trust_remote_code is enabled. This is dangerous.')
|
logger.warning('trust_remote_code is enabled. This is dangerous.')
|
||||||
if 'COLAB_GPU' not in os.environ and not args.nowebui:
|
if 'COLAB_GPU' not in os.environ and not args.nowebui:
|
||||||
if args.share:
|
if args.share:
|
||||||
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
|
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
|
||||||
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
|
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
|
||||||
|
@ -249,7 +249,7 @@ def backup_adapter(input_folder):
|
|||||||
adapter_file = Path(f"{input_folder}/adapter_model.bin")
|
adapter_file = Path(f"{input_folder}/adapter_model.bin")
|
||||||
if adapter_file.is_file():
|
if adapter_file.is_file():
|
||||||
|
|
||||||
logger.info("Backing up existing LoRA adapter...")
|
logger.info("Backing up existing LoRA adapter")
|
||||||
creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
|
creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
|
||||||
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
|
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
|
||||||
|
|
||||||
@ -406,7 +406,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
# == Prep the dataset, format, etc ==
|
# == Prep the dataset, format, etc ==
|
||||||
if raw_text_file not in ['None', '']:
|
if raw_text_file not in ['None', '']:
|
||||||
train_template["template_type"] = "raw_text"
|
train_template["template_type"] = "raw_text"
|
||||||
logger.info("Loading raw text file dataset...")
|
logger.info("Loading raw text file dataset")
|
||||||
fullpath = clean_path('training/datasets', f'{raw_text_file}')
|
fullpath = clean_path('training/datasets', f'{raw_text_file}')
|
||||||
fullpath = Path(fullpath)
|
fullpath = Path(fullpath)
|
||||||
if fullpath.is_dir():
|
if fullpath.is_dir():
|
||||||
@ -486,7 +486,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
prompt = generate_prompt(data_point)
|
prompt = generate_prompt(data_point)
|
||||||
return tokenize(prompt, add_eos_token)
|
return tokenize(prompt, add_eos_token)
|
||||||
|
|
||||||
logger.info("Loading JSON datasets...")
|
logger.info("Loading JSON datasets")
|
||||||
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
||||||
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||||
|
|
||||||
@ -516,13 +516,13 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
|
|
||||||
# == Start prepping the model itself ==
|
# == Start prepping the model itself ==
|
||||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||||
logger.info("Getting model ready...")
|
logger.info("Getting model ready")
|
||||||
prepare_model_for_kbit_training(shared.model)
|
prepare_model_for_kbit_training(shared.model)
|
||||||
|
|
||||||
# base model is now frozen and should not be reused for any other LoRA training than this one
|
# base model is now frozen and should not be reused for any other LoRA training than this one
|
||||||
shared.model_dirty_from_training = True
|
shared.model_dirty_from_training = True
|
||||||
|
|
||||||
logger.info("Preparing for training...")
|
logger.info("Preparing for training")
|
||||||
config = LoraConfig(
|
config = LoraConfig(
|
||||||
r=lora_rank,
|
r=lora_rank,
|
||||||
lora_alpha=lora_alpha,
|
lora_alpha=lora_alpha,
|
||||||
@ -540,10 +540,10 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
|
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("Creating LoRA model...")
|
logger.info("Creating LoRA model")
|
||||||
lora_model = get_peft_model(shared.model, config)
|
lora_model = get_peft_model(shared.model, config)
|
||||||
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
||||||
logger.info("Loading existing LoRA data...")
|
logger.info("Loading existing LoRA data")
|
||||||
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
|
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
|
||||||
set_peft_model_state_dict(lora_model, state_dict_peft)
|
set_peft_model_state_dict(lora_model, state_dict_peft)
|
||||||
except:
|
except:
|
||||||
@ -648,7 +648,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
json.dump(train_template, file, indent=2)
|
json.dump(train_template, file, indent=2)
|
||||||
|
|
||||||
# == Main run and monitor loop ==
|
# == Main run and monitor loop ==
|
||||||
logger.info("Starting training...")
|
logger.info("Starting training")
|
||||||
yield "Starting..."
|
yield "Starting..."
|
||||||
|
|
||||||
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
||||||
@ -730,7 +730,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
|
|
||||||
# Saving in the train thread might fail if an error occurs, so save here if so.
|
# Saving in the train thread might fail if an error occurs, so save here if so.
|
||||||
if not tracked.did_save:
|
if not tracked.did_save:
|
||||||
logger.info("Training complete, saving...")
|
logger.info("Training complete, saving")
|
||||||
lora_model.save_pretrained(lora_file_path)
|
lora_model.save_pretrained(lora_file_path)
|
||||||
|
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
|
@ -4,7 +4,6 @@ datasets
|
|||||||
einops
|
einops
|
||||||
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
|
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
|
||||||
gradio==3.50.*
|
gradio==3.50.*
|
||||||
hqq==0.1.1
|
|
||||||
markdown
|
markdown
|
||||||
numpy==1.24.*
|
numpy==1.24.*
|
||||||
optimum==1.16.*
|
optimum==1.16.*
|
||||||
|
@ -4,7 +4,6 @@ datasets
|
|||||||
einops
|
einops
|
||||||
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
|
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
|
||||||
gradio==3.50.*
|
gradio==3.50.*
|
||||||
hqq==0.1.1
|
|
||||||
markdown
|
markdown
|
||||||
numpy==1.24.*
|
numpy==1.24.*
|
||||||
optimum==1.16.*
|
optimum==1.16.*
|
||||||
|
@ -4,7 +4,6 @@ datasets
|
|||||||
einops
|
einops
|
||||||
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
|
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
|
||||||
gradio==3.50.*
|
gradio==3.50.*
|
||||||
hqq==0.1.1
|
|
||||||
markdown
|
markdown
|
||||||
numpy==1.24.*
|
numpy==1.24.*
|
||||||
optimum==1.16.*
|
optimum==1.16.*
|
||||||
|
@ -4,7 +4,6 @@ datasets
|
|||||||
einops
|
einops
|
||||||
exllamav2==0.0.11
|
exllamav2==0.0.11
|
||||||
gradio==3.50.*
|
gradio==3.50.*
|
||||||
hqq==0.1.1
|
|
||||||
markdown
|
markdown
|
||||||
numpy==1.24.*
|
numpy==1.24.*
|
||||||
optimum==1.16.*
|
optimum==1.16.*
|
||||||
|
@ -4,7 +4,6 @@ datasets
|
|||||||
einops
|
einops
|
||||||
exllamav2==0.0.11
|
exllamav2==0.0.11
|
||||||
gradio==3.50.*
|
gradio==3.50.*
|
||||||
hqq==0.1.1
|
|
||||||
markdown
|
markdown
|
||||||
numpy==1.24.*
|
numpy==1.24.*
|
||||||
optimum==1.16.*
|
optimum==1.16.*
|
||||||
|
@ -4,7 +4,6 @@ datasets
|
|||||||
einops
|
einops
|
||||||
exllamav2==0.0.11
|
exllamav2==0.0.11
|
||||||
gradio==3.50.*
|
gradio==3.50.*
|
||||||
hqq==0.1.1
|
|
||||||
markdown
|
markdown
|
||||||
numpy==1.24.*
|
numpy==1.24.*
|
||||||
optimum==1.16.*
|
optimum==1.16.*
|
||||||
|
@ -4,7 +4,6 @@ datasets
|
|||||||
einops
|
einops
|
||||||
exllamav2==0.0.11
|
exllamav2==0.0.11
|
||||||
gradio==3.50.*
|
gradio==3.50.*
|
||||||
hqq==0.1.1
|
|
||||||
markdown
|
markdown
|
||||||
numpy==1.24.*
|
numpy==1.24.*
|
||||||
optimum==1.16.*
|
optimum==1.16.*
|
||||||
|
@ -4,7 +4,6 @@ datasets
|
|||||||
einops
|
einops
|
||||||
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
|
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
|
||||||
gradio==3.50.*
|
gradio==3.50.*
|
||||||
hqq==0.1.1
|
|
||||||
markdown
|
markdown
|
||||||
numpy==1.24.*
|
numpy==1.24.*
|
||||||
optimum==1.16.*
|
optimum==1.16.*
|
||||||
|
@ -4,7 +4,6 @@ datasets
|
|||||||
einops
|
einops
|
||||||
exllamav2==0.0.11
|
exllamav2==0.0.11
|
||||||
gradio==3.50.*
|
gradio==3.50.*
|
||||||
hqq==0.1.1
|
|
||||||
markdown
|
markdown
|
||||||
numpy==1.24.*
|
numpy==1.24.*
|
||||||
optimum==1.16.*
|
optimum==1.16.*
|
||||||
|
@ -12,6 +12,8 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
|||||||
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
||||||
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
|
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
|
||||||
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
|
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')
|
||||||
|
|
||||||
with RequestBlocker():
|
with RequestBlocker():
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -54,6 +56,7 @@ from modules.models_settings import (
|
|||||||
get_model_metadata,
|
get_model_metadata,
|
||||||
update_model_parameters
|
update_model_parameters
|
||||||
)
|
)
|
||||||
|
from modules.shared import do_cmd_flags_warnings
|
||||||
from modules.utils import gradio
|
from modules.utils import gradio
|
||||||
|
|
||||||
|
|
||||||
@ -170,6 +173,9 @@ def create_interface():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
logger.info("Starting Text generation web UI")
|
||||||
|
do_cmd_flags_warnings()
|
||||||
|
|
||||||
# Load custom settings
|
# Load custom settings
|
||||||
settings_file = None
|
settings_file = None
|
||||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||||
@ -180,7 +186,7 @@ if __name__ == "__main__":
|
|||||||
settings_file = Path('settings.json')
|
settings_file = Path('settings.json')
|
||||||
|
|
||||||
if settings_file is not None:
|
if settings_file is not None:
|
||||||
logger.info(f"Loading settings from {settings_file}...")
|
logger.info(f"Loading settings from {settings_file}")
|
||||||
file_contents = open(settings_file, 'r', encoding='utf-8').read()
|
file_contents = open(settings_file, 'r', encoding='utf-8').read()
|
||||||
new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents)
|
new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents)
|
||||||
shared.settings.update(new_settings)
|
shared.settings.update(new_settings)
|
||||||
|
Loading…
Reference in New Issue
Block a user