From d8279dc71083a1b150bde340f95d17b2cb1d6a75 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 19 Dec 2023 17:31:46 -0800 Subject: [PATCH 1/8] Replace character name placeholders in chat context (closes #5007) --- modules/chat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/chat.py b/modules/chat.py index 973a7fbd..8ddb7531 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -95,7 +95,8 @@ def generate_chat_prompt(user_input, state, **kwargs): else: renderer = chat_renderer if state['context'].strip() != '': - messages.append({"role": "system", "content": state['context']}) + context = replace_character_names(state['context'], state['name1'], state['name2']) + messages.append({"role": "system", "content": context}) insert_pos = len(messages) for user_msg, assistant_msg in reversed(history): From 95600073bcd0430d2195bd82e09d82398455ae03 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 19 Dec 2023 20:20:45 -0800 Subject: [PATCH 2/8] Add an informative error when extension requirements are missing --- extensions/coqui_tts/script.py | 17 ++--------------- modules/extensions.py | 7 ++++++- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/extensions/coqui_tts/script.py b/extensions/coqui_tts/script.py index 8eab7afb..3b241c58 100644 --- a/extensions/coqui_tts/script.py +++ b/extensions/coqui_tts/script.py @@ -6,27 +6,14 @@ import time from pathlib import Path import gradio as gr +from TTS.api import TTS +from TTS.utils.synthesizer import Synthesizer from modules import chat, shared, ui_chat from modules.logging_colors import logger from modules.ui import create_refresh_button from modules.utils import gradio -try: - from TTS.api import TTS - from TTS.utils.synthesizer import Synthesizer -except ModuleNotFoundError: - logger.error( - "Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension." - "\n" - "\nLinux / Mac:\npip install -r extensions/coqui_tts/requirements.txt\n" - "\nWindows:\npip install -r extensions\\coqui_tts\\requirements.txt\n" - "\n" - "If you used the one-click installer, paste the command above in the terminal window launched after running the \"cmd_\" script. On Windows, that's \"cmd_windows.bat\"." - ) - - raise - os.environ["COQUI_TOS_AGREED"] = "1" params = { diff --git a/modules/extensions.py b/modules/extensions.py index 6c072504..25fcc0a3 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -33,7 +33,12 @@ def load_extensions(): if name != 'api': logger.info(f'Loading the extension "{name}"...') try: - exec(f"import extensions.{name}.script") + try: + exec(f"import extensions.{name}.script") + except ModuleNotFoundError: + logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n\nIf you used the one-click installer, paste the command above in the terminal window opened after launching the cmd script for your OS.") + raise + extension = getattr(extensions, name).script apply_settings(extension, name) if extension not in setup_called and hasattr(extension, "setup"): From 23818dc0981a0dbf3a49bef3beb015929f8c81f2 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 19 Dec 2023 20:38:33 -0800 Subject: [PATCH 3/8] Better logger Credits: vladmandic/automatic --- extensions/openai/script.py | 2 + modules/logging_colors.py | 178 +++++++++++++----------------------- 2 files changed, 66 insertions(+), 114 deletions(-) diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 1958c30f..0be83442 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -1,5 +1,6 @@ import asyncio import json +import logging import os import traceback from threading import Thread @@ -367,6 +368,7 @@ def run_server(): if shared.args.admin_key and shared.args.admin_key != shared.args.api_key: logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n') + logging.getLogger("uvicorn.error").propagate = False uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile) diff --git a/modules/logging_colors.py b/modules/logging_colors.py index a0c97c3a..ba760a2c 100644 --- a/modules/logging_colors.py +++ b/modules/logging_colors.py @@ -1,117 +1,67 @@ -# Copied from https://stackoverflow.com/a/1336640 - import logging -import platform - -logging.basicConfig( - format='%(asctime)s %(levelname)s:%(message)s', - datefmt='%Y-%m-%d %H:%M:%S', -) - - -def add_coloring_to_emit_windows(fn): - # add methods we need to the class - def _out_handle(self): - import ctypes - return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE) - out_handle = property(_out_handle) - - def _set_color(self, code): - import ctypes - - # Constants from the Windows API - self.STD_OUTPUT_HANDLE = -11 - hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE) - ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code) - - setattr(logging.StreamHandler, '_set_color', _set_color) - - def new(*args): - FOREGROUND_BLUE = 0x0001 # text color contains blue. - FOREGROUND_GREEN = 0x0002 # text color contains green. - FOREGROUND_RED = 0x0004 # text color contains red. - FOREGROUND_INTENSITY = 0x0008 # text color is intensified. - FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED - # winbase.h - # STD_INPUT_HANDLE = -10 - # STD_OUTPUT_HANDLE = -11 - # STD_ERROR_HANDLE = -12 - - # wincon.h - # FOREGROUND_BLACK = 0x0000 - FOREGROUND_BLUE = 0x0001 - FOREGROUND_GREEN = 0x0002 - # FOREGROUND_CYAN = 0x0003 - FOREGROUND_RED = 0x0004 - FOREGROUND_MAGENTA = 0x0005 - FOREGROUND_YELLOW = 0x0006 - # FOREGROUND_GREY = 0x0007 - FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified. - - # BACKGROUND_BLACK = 0x0000 - # BACKGROUND_BLUE = 0x0010 - # BACKGROUND_GREEN = 0x0020 - # BACKGROUND_CYAN = 0x0030 - # BACKGROUND_RED = 0x0040 - # BACKGROUND_MAGENTA = 0x0050 - BACKGROUND_YELLOW = 0x0060 - # BACKGROUND_GREY = 0x0070 - BACKGROUND_INTENSITY = 0x0080 # background color is intensified. - - levelno = args[1].levelno - if (levelno >= 50): - color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY - elif (levelno >= 40): - color = FOREGROUND_RED | FOREGROUND_INTENSITY - elif (levelno >= 30): - color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY - elif (levelno >= 20): - color = FOREGROUND_GREEN - elif (levelno >= 10): - color = FOREGROUND_MAGENTA - else: - color = FOREGROUND_WHITE - args[0]._set_color(color) - - ret = fn(*args) - args[0]._set_color(FOREGROUND_WHITE) - # print "after" - return ret - return new - - -def add_coloring_to_emit_ansi(fn): - # add methods we need to the class - def new(*args): - levelno = args[1].levelno - if (levelno >= 50): - color = '\x1b[31m' # red - elif (levelno >= 40): - color = '\x1b[31m' # red - elif (levelno >= 30): - color = '\x1b[33m' # yellow - elif (levelno >= 20): - color = '\x1b[32m' # green - elif (levelno >= 10): - color = '\x1b[35m' # pink - else: - color = '\x1b[0m' # normal - args[1].msg = color + args[1].msg + '\x1b[0m' # normal - # print "after" - return fn(*args) - return new - - -if platform.system() == 'Windows': - # Windows does not support ANSI escapes and we are using API calls to set the console color - logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit) -else: - # all non-Windows platforms are supporting ANSI escapes so we use them - logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit) - # log = logging.getLogger() - # log.addFilter(log_filter()) - # //hdlr = logging.StreamHandler() - # //hdlr.setFormatter(formatter()) logger = logging.getLogger('text-generation-webui') -logger.setLevel(logging.DEBUG) + + +def setup_logging(): + ''' + Copied from: https://github.com/vladmandic/automatic + + All credits to vladmandic. + ''' + + class RingBuffer(logging.StreamHandler): + def __init__(self, capacity): + super().__init__() + self.capacity = capacity + self.buffer = [] + self.formatter = logging.Formatter('{ "asctime":"%(asctime)s", "created":%(created)f, "facility":"%(name)s", "pid":%(process)d, "tid":%(thread)d, "level":"%(levelname)s", "module":"%(module)s", "func":"%(funcName)s", "msg":"%(message)s" }') + + def emit(self, record): + msg = self.format(record) + # self.buffer.append(json.loads(msg)) + self.buffer.append(msg) + if len(self.buffer) > self.capacity: + self.buffer.pop(0) + + def get(self): + return self.buffer + + from rich.theme import Theme + from rich.logging import RichHandler + from rich.console import Console + from rich.pretty import install as pretty_install + from rich.traceback import install as traceback_install + + level = logging.DEBUG + logger.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd` + console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({ + "traceback.border": "black", + "traceback.border.syntax_error": "black", + "inspect.value.border": "black", + })) + logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null + pretty_install(console=console) + traceback_install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[]) + while logger.hasHandlers() and len(logger.handlers) > 0: + logger.removeHandler(logger.handlers[0]) + + # handlers + rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console) + rh.setLevel(level) + logger.addHandler(rh) + + rb = RingBuffer(100) # 100 entries default in log ring buffer + rb.setLevel(level) + logger.addHandler(rb) + logger.buffer = rb.buffer + + # overrides + logging.getLogger("urllib3").setLevel(logging.ERROR) + logging.getLogger("httpx").setLevel(logging.ERROR) + logging.getLogger("diffusers").setLevel(logging.ERROR) + logging.getLogger("torch").setLevel(logging.ERROR) + logging.getLogger("lycoris").handlers = logger.handlers + + +setup_logging() From 9992f7d8c0b35c6e076e4e3cdd726cbd63d56b85 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 19 Dec 2023 20:54:32 -0800 Subject: [PATCH 4/8] Improve several log messages --- modules/GPTQ_loader.py | 2 +- modules/LoRA.py | 2 +- modules/extensions.py | 2 +- modules/models.py | 2 +- modules/shared.py | 32 ++++++++++++++++++-------------- modules/training.py | 18 +++++++++--------- server.py | 7 ++++++- 7 files changed, 37 insertions(+), 28 deletions(-) diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index 7dc20b0a..601c58f3 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -126,7 +126,7 @@ def load_quantized(model_name): path_to_model = Path(f'{shared.args.model_dir}/{model_name}') pt_path = find_quantized_model_file(model_name) if not pt_path: - logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...") + logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.") exit() else: logger.info(f"Found the following quantized model: {pt_path}") diff --git a/modules/LoRA.py b/modules/LoRA.py index dea476ad..97027eb4 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -138,7 +138,7 @@ def add_lora_transformers(lora_names): # Add a LoRA when another LoRA is already present if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys(): - logger.info(f"Adding the LoRA(s) named {added_set} to the model...") + logger.info(f"Adding the LoRA(s) named {added_set} to the model") for lora in added_set: shared.model.load_adapter(get_lora_path(lora), lora) diff --git a/modules/extensions.py b/modules/extensions.py index 25fcc0a3..2a3b0bb1 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -31,7 +31,7 @@ def load_extensions(): for i, name in enumerate(shared.args.extensions): if name in available_extensions: if name != 'api': - logger.info(f'Loading the extension "{name}"...') + logger.info(f'Loading the extension "{name}"') try: try: exec(f"import extensions.{name}.script") diff --git a/modules/models.py b/modules/models.py index 5a23f743..f37f3d60 100644 --- a/modules/models.py +++ b/modules/models.py @@ -54,7 +54,7 @@ sampler_hijack.hijack_samplers() def load_model(model_name, loader=None): - logger.info(f"Loading {model_name}...") + logger.info(f"Loading {model_name}") t0 = time.time() shared.is_seq2seq = False diff --git a/modules/shared.py b/modules/shared.py index e0e77362..bcb20905 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -204,22 +204,26 @@ for arg in sys.argv[1:]: if hasattr(args, arg): provided_arguments.append(arg) -# Deprecation warnings deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast'] -for k in deprecated_args: - if getattr(args, k): - logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.') -# Security warnings -if args.trust_remote_code: - logger.warning('trust_remote_code is enabled. This is dangerous.') -if 'COLAB_GPU' not in os.environ and not args.nowebui: - if args.share: - logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.") - if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)): - logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.") - if args.multi_user: - logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.') + +def do_cmd_flags_warnings(): + + # Deprecation warnings + for k in deprecated_args: + if getattr(args, k): + logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.') + + # Security warnings + if args.trust_remote_code: + logger.warning('trust_remote_code is enabled. This is dangerous.') + if 'COLAB_GPU' not in os.environ and not args.nowebui: + if args.share: + logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.") + if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)): + logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.") + if args.multi_user: + logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.') def fix_loader_name(name): diff --git a/modules/training.py b/modules/training.py index ca1fffb3..b0e02400 100644 --- a/modules/training.py +++ b/modules/training.py @@ -249,7 +249,7 @@ def backup_adapter(input_folder): adapter_file = Path(f"{input_folder}/adapter_model.bin") if adapter_file.is_file(): - logger.info("Backing up existing LoRA adapter...") + logger.info("Backing up existing LoRA adapter") creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime) creation_date_str = creation_date.strftime("Backup-%Y-%m-%d") @@ -406,7 +406,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: # == Prep the dataset, format, etc == if raw_text_file not in ['None', '']: train_template["template_type"] = "raw_text" - logger.info("Loading raw text file dataset...") + logger.info("Loading raw text file dataset") fullpath = clean_path('training/datasets', f'{raw_text_file}') fullpath = Path(fullpath) if fullpath.is_dir(): @@ -486,7 +486,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: prompt = generate_prompt(data_point) return tokenize(prompt, add_eos_token) - logger.info("Loading JSON datasets...") + logger.info("Loading JSON datasets") data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json')) train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30)) @@ -516,13 +516,13 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: # == Start prepping the model itself == if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): - logger.info("Getting model ready...") + logger.info("Getting model ready") prepare_model_for_kbit_training(shared.model) # base model is now frozen and should not be reused for any other LoRA training than this one shared.model_dirty_from_training = True - logger.info("Preparing for training...") + logger.info("Preparing for training") config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, @@ -540,10 +540,10 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: model_trainable_params, model_all_params = calc_trainable_parameters(shared.model) try: - logger.info("Creating LoRA model...") + logger.info("Creating LoRA model") lora_model = get_peft_model(shared.model, config) if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file(): - logger.info("Loading existing LoRA data...") + logger.info("Loading existing LoRA data") state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True) set_peft_model_state_dict(lora_model, state_dict_peft) except: @@ -648,7 +648,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: json.dump(train_template, file, indent=2) # == Main run and monitor loop == - logger.info("Starting training...") + logger.info("Starting training") yield "Starting..." lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model) @@ -730,7 +730,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: # Saving in the train thread might fail if an error occurs, so save here if so. if not tracked.did_save: - logger.info("Training complete, saving...") + logger.info("Training complete, saving") lora_model.save_pretrained(lora_file_path) if WANT_INTERRUPT: diff --git a/server.py b/server.py index ae0aed09..d5d11bc4 100644 --- a/server.py +++ b/server.py @@ -12,6 +12,7 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1' warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict') +warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()') with RequestBlocker(): import gradio as gr @@ -54,6 +55,7 @@ from modules.models_settings import ( get_model_metadata, update_model_parameters ) +from modules.shared import do_cmd_flags_warnings from modules.utils import gradio @@ -170,6 +172,9 @@ def create_interface(): if __name__ == "__main__": + logger.info("Starting Text generation web UI") + do_cmd_flags_warnings() + # Load custom settings settings_file = None if shared.args.settings is not None and Path(shared.args.settings).exists(): @@ -180,7 +185,7 @@ if __name__ == "__main__": settings_file = Path('settings.json') if settings_file is not None: - logger.info(f"Loading settings from {settings_file}...") + logger.info(f"Loading settings from {settings_file}") file_contents = open(settings_file, 'r', encoding='utf-8').read() new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents) shared.settings.update(new_settings) From 366c93a008cb2c0cf23e88d6bdbb757626537688 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 19 Dec 2023 21:03:20 -0800 Subject: [PATCH 5/8] Hide a warning --- server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server.py b/server.py index d5d11bc4..c53882f6 100644 --- a/server.py +++ b/server.py @@ -13,6 +13,7 @@ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict') warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()') +warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict') with RequestBlocker(): import gradio as gr From fb8ee9f7ff164aff95a747c73d3924e9613a76b8 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 19 Dec 2023 21:32:58 -0800 Subject: [PATCH 6/8] Add a specific error if HQQ is missing --- modules/models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/models.py b/modules/models.py index f37f3d60..70316952 100644 --- a/modules/models.py +++ b/modules/models.py @@ -413,8 +413,12 @@ def ExLlamav2_HF_loader(model_name): def HQQ_loader(model_name): - from hqq.engine.hf import HQQModelForCausalLM - from hqq.core.quantize import HQQLinear, HQQBackend + try: + from hqq.engine.hf import HQQModelForCausalLM + from hqq.core.quantize import HQQLinear, HQQBackend + except ModuleNotFoundError: + logger.error("HQQ is not installed. You can install it with:\n\npip install hqq") + return None logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}") From 2289e9031e50326ddfae962db6e7f3cc6225077f Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 19 Dec 2023 21:33:49 -0800 Subject: [PATCH 7/8] Remove HQQ from requirements (after https://github.com/oobabooga/text-generation-webui/issues/4993) --- requirements.txt | 1 - requirements_amd.txt | 1 - requirements_amd_noavx2.txt | 1 - requirements_apple_intel.txt | 1 - requirements_apple_silicon.txt | 1 - requirements_cpu_only.txt | 1 - requirements_cpu_only_noavx2.txt | 1 - requirements_noavx2.txt | 1 - requirements_nowheels.txt | 1 - 9 files changed, 9 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9f2979ab..c7c6edaa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ datasets einops exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64" gradio==3.50.* -hqq==0.1.1 markdown numpy==1.24.* optimum==1.16.* diff --git a/requirements_amd.txt b/requirements_amd.txt index dcdf6986..6972b4aa 100644 --- a/requirements_amd.txt +++ b/requirements_amd.txt @@ -4,7 +4,6 @@ datasets einops exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64" gradio==3.50.* -hqq==0.1.1 markdown numpy==1.24.* optimum==1.16.* diff --git a/requirements_amd_noavx2.txt b/requirements_amd_noavx2.txt index 9d8e195a..af58c5c5 100644 --- a/requirements_amd_noavx2.txt +++ b/requirements_amd_noavx2.txt @@ -4,7 +4,6 @@ datasets einops exllamav2==0.0.11; platform_system == "Windows" or python_version < "3.10" or python_version > "3.11" or platform_machine != "x86_64" gradio==3.50.* -hqq==0.1.1 markdown numpy==1.24.* optimum==1.16.* diff --git a/requirements_apple_intel.txt b/requirements_apple_intel.txt index 03c3859f..a4147217 100644 --- a/requirements_apple_intel.txt +++ b/requirements_apple_intel.txt @@ -4,7 +4,6 @@ datasets einops exllamav2==0.0.11 gradio==3.50.* -hqq==0.1.1 markdown numpy==1.24.* optimum==1.16.* diff --git a/requirements_apple_silicon.txt b/requirements_apple_silicon.txt index 1a775a54..d36c7d1b 100644 --- a/requirements_apple_silicon.txt +++ b/requirements_apple_silicon.txt @@ -4,7 +4,6 @@ datasets einops exllamav2==0.0.11 gradio==3.50.* -hqq==0.1.1 markdown numpy==1.24.* optimum==1.16.* diff --git a/requirements_cpu_only.txt b/requirements_cpu_only.txt index 3e5c524b..c6b1a254 100644 --- a/requirements_cpu_only.txt +++ b/requirements_cpu_only.txt @@ -4,7 +4,6 @@ datasets einops exllamav2==0.0.11 gradio==3.50.* -hqq==0.1.1 markdown numpy==1.24.* optimum==1.16.* diff --git a/requirements_cpu_only_noavx2.txt b/requirements_cpu_only_noavx2.txt index f972a794..c442e525 100644 --- a/requirements_cpu_only_noavx2.txt +++ b/requirements_cpu_only_noavx2.txt @@ -4,7 +4,6 @@ datasets einops exllamav2==0.0.11 gradio==3.50.* -hqq==0.1.1 markdown numpy==1.24.* optimum==1.16.* diff --git a/requirements_noavx2.txt b/requirements_noavx2.txt index 08b85092..0d92f414 100644 --- a/requirements_noavx2.txt +++ b/requirements_noavx2.txt @@ -4,7 +4,6 @@ datasets einops exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64" gradio==3.50.* -hqq==0.1.1 markdown numpy==1.24.* optimum==1.16.* diff --git a/requirements_nowheels.txt b/requirements_nowheels.txt index cabccf7c..bc5cadcb 100644 --- a/requirements_nowheels.txt +++ b/requirements_nowheels.txt @@ -4,7 +4,6 @@ datasets einops exllamav2==0.0.11 gradio==3.50.* -hqq==0.1.1 markdown numpy==1.24.* optimum==1.16.* From fadb295d4dbec37d806d3b8fd922cddb976adbdc Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 19 Dec 2023 21:36:57 -0800 Subject: [PATCH 8/8] Lint --- extensions/coqui_tts/script.py | 1 - modules/logging_colors.py | 4 ++-- modules/models.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/extensions/coqui_tts/script.py b/extensions/coqui_tts/script.py index 3b241c58..26d6b76a 100644 --- a/extensions/coqui_tts/script.py +++ b/extensions/coqui_tts/script.py @@ -10,7 +10,6 @@ from TTS.api import TTS from TTS.utils.synthesizer import Synthesizer from modules import chat, shared, ui_chat -from modules.logging_colors import logger from modules.ui import create_refresh_button from modules.utils import gradio diff --git a/modules/logging_colors.py b/modules/logging_colors.py index ba760a2c..b9791e26 100644 --- a/modules/logging_colors.py +++ b/modules/logging_colors.py @@ -27,10 +27,10 @@ def setup_logging(): def get(self): return self.buffer - from rich.theme import Theme - from rich.logging import RichHandler from rich.console import Console + from rich.logging import RichHandler from rich.pretty import install as pretty_install + from rich.theme import Theme from rich.traceback import install as traceback_install level = logging.DEBUG diff --git a/modules/models.py b/modules/models.py index 70316952..cad6a165 100644 --- a/modules/models.py +++ b/modules/models.py @@ -414,8 +414,8 @@ def ExLlamav2_HF_loader(model_name): def HQQ_loader(model_name): try: + from hqq.core.quantize import HQQBackend, HQQLinear from hqq.engine.hf import HQQModelForCausalLM - from hqq.core.quantize import HQQLinear, HQQBackend except ModuleNotFoundError: logger.error("HQQ is not installed. You can install it with:\n\npip install hqq") return None