Merge pull request #5011 from oobabooga/dev

Merge dev branch
This commit is contained in:
oobabooga 2023-12-20 02:38:25 -03:00 committed by GitHub
commit c1f78dbd0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 120 additions and 171 deletions

View File

@ -6,26 +6,12 @@ import time
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
from modules import chat, shared, ui_chat
from modules.logging_colors import logger
from modules.ui import create_refresh_button
from modules.utils import gradio
try:
from TTS.api import TTS from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer from TTS.utils.synthesizer import Synthesizer
except ModuleNotFoundError:
logger.error(
"Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension."
"\n"
"\nLinux / Mac:\npip install -r extensions/coqui_tts/requirements.txt\n"
"\nWindows:\npip install -r extensions\\coqui_tts\\requirements.txt\n"
"\n"
"If you used the one-click installer, paste the command above in the terminal window launched after running the \"cmd_\" script. On Windows, that's \"cmd_windows.bat\"."
)
raise from modules import chat, shared, ui_chat
from modules.ui import create_refresh_button
from modules.utils import gradio
os.environ["COQUI_TOS_AGREED"] = "1" os.environ["COQUI_TOS_AGREED"] = "1"

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import json import json
import logging
import os import os
import traceback import traceback
from threading import Thread from threading import Thread
@ -367,6 +368,7 @@ def run_server():
if shared.args.admin_key and shared.args.admin_key != shared.args.api_key: if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n') logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')
logging.getLogger("uvicorn.error").propagate = False
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile) uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)

View File

@ -126,7 +126,7 @@ def load_quantized(model_name):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}') path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
pt_path = find_quantized_model_file(model_name) pt_path = find_quantized_model_file(model_name)
if not pt_path: if not pt_path:
logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...") logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
exit() exit()
else: else:
logger.info(f"Found the following quantized model: {pt_path}") logger.info(f"Found the following quantized model: {pt_path}")

View File

@ -138,7 +138,7 @@ def add_lora_transformers(lora_names):
# Add a LoRA when another LoRA is already present # Add a LoRA when another LoRA is already present
if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys(): if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
logger.info(f"Adding the LoRA(s) named {added_set} to the model...") logger.info(f"Adding the LoRA(s) named {added_set} to the model")
for lora in added_set: for lora in added_set:
shared.model.load_adapter(get_lora_path(lora), lora) shared.model.load_adapter(get_lora_path(lora), lora)

View File

@ -95,7 +95,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
else: else:
renderer = chat_renderer renderer = chat_renderer
if state['context'].strip() != '': if state['context'].strip() != '':
messages.append({"role": "system", "content": state['context']}) context = replace_character_names(state['context'], state['name1'], state['name2'])
messages.append({"role": "system", "content": context})
insert_pos = len(messages) insert_pos = len(messages)
for user_msg, assistant_msg in reversed(history): for user_msg, assistant_msg in reversed(history):

View File

@ -31,9 +31,14 @@ def load_extensions():
for i, name in enumerate(shared.args.extensions): for i, name in enumerate(shared.args.extensions):
if name in available_extensions: if name in available_extensions:
if name != 'api': if name != 'api':
logger.info(f'Loading the extension "{name}"...') logger.info(f'Loading the extension "{name}"')
try:
try: try:
exec(f"import extensions.{name}.script") exec(f"import extensions.{name}.script")
except ModuleNotFoundError:
logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n\nIf you used the one-click installer, paste the command above in the terminal window opened after launching the cmd script for your OS.")
raise
extension = getattr(extensions, name).script extension = getattr(extensions, name).script
apply_settings(extension, name) apply_settings(extension, name)
if extension not in setup_called and hasattr(extension, "setup"): if extension not in setup_called and hasattr(extension, "setup"):

View File

@ -1,117 +1,67 @@
# Copied from https://stackoverflow.com/a/1336640
import logging import logging
import platform
logging.basicConfig(
format='%(asctime)s %(levelname)s:%(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
def add_coloring_to_emit_windows(fn):
# add methods we need to the class
def _out_handle(self):
import ctypes
return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
out_handle = property(_out_handle)
def _set_color(self, code):
import ctypes
# Constants from the Windows API
self.STD_OUTPUT_HANDLE = -11
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
setattr(logging.StreamHandler, '_set_color', _set_color)
def new(*args):
FOREGROUND_BLUE = 0x0001 # text color contains blue.
FOREGROUND_GREEN = 0x0002 # text color contains green.
FOREGROUND_RED = 0x0004 # text color contains red.
FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
# winbase.h
# STD_INPUT_HANDLE = -10
# STD_OUTPUT_HANDLE = -11
# STD_ERROR_HANDLE = -12
# wincon.h
# FOREGROUND_BLACK = 0x0000
FOREGROUND_BLUE = 0x0001
FOREGROUND_GREEN = 0x0002
# FOREGROUND_CYAN = 0x0003
FOREGROUND_RED = 0x0004
FOREGROUND_MAGENTA = 0x0005
FOREGROUND_YELLOW = 0x0006
# FOREGROUND_GREY = 0x0007
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
# BACKGROUND_BLACK = 0x0000
# BACKGROUND_BLUE = 0x0010
# BACKGROUND_GREEN = 0x0020
# BACKGROUND_CYAN = 0x0030
# BACKGROUND_RED = 0x0040
# BACKGROUND_MAGENTA = 0x0050
BACKGROUND_YELLOW = 0x0060
# BACKGROUND_GREY = 0x0070
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.
levelno = args[1].levelno
if (levelno >= 50):
color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
elif (levelno >= 40):
color = FOREGROUND_RED | FOREGROUND_INTENSITY
elif (levelno >= 30):
color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
elif (levelno >= 20):
color = FOREGROUND_GREEN
elif (levelno >= 10):
color = FOREGROUND_MAGENTA
else:
color = FOREGROUND_WHITE
args[0]._set_color(color)
ret = fn(*args)
args[0]._set_color(FOREGROUND_WHITE)
# print "after"
return ret
return new
def add_coloring_to_emit_ansi(fn):
# add methods we need to the class
def new(*args):
levelno = args[1].levelno
if (levelno >= 50):
color = '\x1b[31m' # red
elif (levelno >= 40):
color = '\x1b[31m' # red
elif (levelno >= 30):
color = '\x1b[33m' # yellow
elif (levelno >= 20):
color = '\x1b[32m' # green
elif (levelno >= 10):
color = '\x1b[35m' # pink
else:
color = '\x1b[0m' # normal
args[1].msg = color + args[1].msg + '\x1b[0m' # normal
# print "after"
return fn(*args)
return new
if platform.system() == 'Windows':
# Windows does not support ANSI escapes and we are using API calls to set the console color
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
else:
# all non-Windows platforms are supporting ANSI escapes so we use them
logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
# log = logging.getLogger()
# log.addFilter(log_filter())
# //hdlr = logging.StreamHandler()
# //hdlr.setFormatter(formatter())
logger = logging.getLogger('text-generation-webui') logger = logging.getLogger('text-generation-webui')
logger.setLevel(logging.DEBUG)
def setup_logging():
'''
Copied from: https://github.com/vladmandic/automatic
All credits to vladmandic.
'''
class RingBuffer(logging.StreamHandler):
def __init__(self, capacity):
super().__init__()
self.capacity = capacity
self.buffer = []
self.formatter = logging.Formatter('{ "asctime":"%(asctime)s", "created":%(created)f, "facility":"%(name)s", "pid":%(process)d, "tid":%(thread)d, "level":"%(levelname)s", "module":"%(module)s", "func":"%(funcName)s", "msg":"%(message)s" }')
def emit(self, record):
msg = self.format(record)
# self.buffer.append(json.loads(msg))
self.buffer.append(msg)
if len(self.buffer) > self.capacity:
self.buffer.pop(0)
def get(self):
return self.buffer
from rich.console import Console
from rich.logging import RichHandler
from rich.pretty import install as pretty_install
from rich.theme import Theme
from rich.traceback import install as traceback_install
level = logging.DEBUG
logger.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
"traceback.border": "black",
"traceback.border.syntax_error": "black",
"inspect.value.border": "black",
}))
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
pretty_install(console=console)
traceback_install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
while logger.hasHandlers() and len(logger.handlers) > 0:
logger.removeHandler(logger.handlers[0])
# handlers
rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console)
rh.setLevel(level)
logger.addHandler(rh)
rb = RingBuffer(100) # 100 entries default in log ring buffer
rb.setLevel(level)
logger.addHandler(rb)
logger.buffer = rb.buffer
# overrides
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("lycoris").handlers = logger.handlers
setup_logging()

View File

@ -54,7 +54,7 @@ sampler_hijack.hijack_samplers()
def load_model(model_name, loader=None): def load_model(model_name, loader=None):
logger.info(f"Loading {model_name}...") logger.info(f"Loading {model_name}")
t0 = time.time() t0 = time.time()
shared.is_seq2seq = False shared.is_seq2seq = False
@ -413,8 +413,12 @@ def ExLlamav2_HF_loader(model_name):
def HQQ_loader(model_name): def HQQ_loader(model_name):
try:
from hqq.core.quantize import HQQBackend, HQQLinear
from hqq.engine.hf import HQQModelForCausalLM from hqq.engine.hf import HQQModelForCausalLM
from hqq.core.quantize import HQQLinear, HQQBackend except ModuleNotFoundError:
logger.error("HQQ is not installed. You can install it with:\n\npip install hqq")
return None
logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}") logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}")

View File

@ -204,8 +204,12 @@ for arg in sys.argv[1:]:
if hasattr(args, arg): if hasattr(args, arg):
provided_arguments.append(arg) provided_arguments.append(arg)
# Deprecation warnings
deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast'] deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
def do_cmd_flags_warnings():
# Deprecation warnings
for k in deprecated_args: for k in deprecated_args:
if getattr(args, k): if getattr(args, k):
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.') logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')

View File

@ -249,7 +249,7 @@ def backup_adapter(input_folder):
adapter_file = Path(f"{input_folder}/adapter_model.bin") adapter_file = Path(f"{input_folder}/adapter_model.bin")
if adapter_file.is_file(): if adapter_file.is_file():
logger.info("Backing up existing LoRA adapter...") logger.info("Backing up existing LoRA adapter")
creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime) creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d") creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
@ -406,7 +406,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
# == Prep the dataset, format, etc == # == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']: if raw_text_file not in ['None', '']:
train_template["template_type"] = "raw_text" train_template["template_type"] = "raw_text"
logger.info("Loading raw text file dataset...") logger.info("Loading raw text file dataset")
fullpath = clean_path('training/datasets', f'{raw_text_file}') fullpath = clean_path('training/datasets', f'{raw_text_file}')
fullpath = Path(fullpath) fullpath = Path(fullpath)
if fullpath.is_dir(): if fullpath.is_dir():
@ -486,7 +486,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
prompt = generate_prompt(data_point) prompt = generate_prompt(data_point)
return tokenize(prompt, add_eos_token) return tokenize(prompt, add_eos_token)
logger.info("Loading JSON datasets...") logger.info("Loading JSON datasets")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json')) data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30)) train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
@ -516,13 +516,13 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
# == Start prepping the model itself == # == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
logger.info("Getting model ready...") logger.info("Getting model ready")
prepare_model_for_kbit_training(shared.model) prepare_model_for_kbit_training(shared.model)
# base model is now frozen and should not be reused for any other LoRA training than this one # base model is now frozen and should not be reused for any other LoRA training than this one
shared.model_dirty_from_training = True shared.model_dirty_from_training = True
logger.info("Preparing for training...") logger.info("Preparing for training")
config = LoraConfig( config = LoraConfig(
r=lora_rank, r=lora_rank,
lora_alpha=lora_alpha, lora_alpha=lora_alpha,
@ -540,10 +540,10 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model) model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
try: try:
logger.info("Creating LoRA model...") logger.info("Creating LoRA model")
lora_model = get_peft_model(shared.model, config) lora_model = get_peft_model(shared.model, config)
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file(): if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
logger.info("Loading existing LoRA data...") logger.info("Loading existing LoRA data")
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True) state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
set_peft_model_state_dict(lora_model, state_dict_peft) set_peft_model_state_dict(lora_model, state_dict_peft)
except: except:
@ -648,7 +648,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
json.dump(train_template, file, indent=2) json.dump(train_template, file, indent=2)
# == Main run and monitor loop == # == Main run and monitor loop ==
logger.info("Starting training...") logger.info("Starting training")
yield "Starting..." yield "Starting..."
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model) lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
@ -730,7 +730,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
# Saving in the train thread might fail if an error occurs, so save here if so. # Saving in the train thread might fail if an error occurs, so save here if so.
if not tracked.did_save: if not tracked.did_save:
logger.info("Training complete, saving...") logger.info("Training complete, saving")
lora_model.save_pretrained(lora_file_path) lora_model.save_pretrained(lora_file_path)
if WANT_INTERRUPT: if WANT_INTERRUPT:

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64" exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64" exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64" exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64"
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11 exllamav2==0.0.11
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11 exllamav2==0.0.11
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11 exllamav2==0.0.11
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11 exllamav2==0.0.11
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64" exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -4,7 +4,6 @@ datasets
einops einops
exllamav2==0.0.11 exllamav2==0.0.11
gradio==3.50.* gradio==3.50.*
hqq==0.1.1
markdown markdown
numpy==1.24.* numpy==1.24.*
optimum==1.16.* optimum==1.16.*

View File

@ -12,6 +12,8 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict') warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')
with RequestBlocker(): with RequestBlocker():
import gradio as gr import gradio as gr
@ -54,6 +56,7 @@ from modules.models_settings import (
get_model_metadata, get_model_metadata,
update_model_parameters update_model_parameters
) )
from modules.shared import do_cmd_flags_warnings
from modules.utils import gradio from modules.utils import gradio
@ -170,6 +173,9 @@ def create_interface():
if __name__ == "__main__": if __name__ == "__main__":
logger.info("Starting Text generation web UI")
do_cmd_flags_warnings()
# Load custom settings # Load custom settings
settings_file = None settings_file = None
if shared.args.settings is not None and Path(shared.args.settings).exists(): if shared.args.settings is not None and Path(shared.args.settings).exists():
@ -180,7 +186,7 @@ if __name__ == "__main__":
settings_file = Path('settings.json') settings_file = Path('settings.json')
if settings_file is not None: if settings_file is not None:
logger.info(f"Loading settings from {settings_file}...") logger.info(f"Loading settings from {settings_file}")
file_contents = open(settings_file, 'r', encoding='utf-8').read() file_contents = open(settings_file, 'r', encoding='utf-8').read()
new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents) new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents)
shared.settings.update(new_settings) shared.settings.update(new_settings)