From 95d04d6a8dfe2dd6bf3ce83a2bf3799ff5232f0b Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 3 May 2023 21:43:17 -0300 Subject: [PATCH] Better warning messages --- modules/GPTQ_loader.py | 24 +++---- modules/LoRA.py | 5 +- modules/chat.py | 9 +-- modules/extensions.py | 9 +-- modules/llama_attn_hijack.py | 27 +++----- modules/logging_colors.py | 109 ++++++++++++++++++++++++++++++ modules/models.py | 32 +++++---- modules/monkey_patch_gptq_lora.py | 1 + modules/shared.py | 9 +-- modules/text_generation.py | 3 +- modules/training.py | 35 +++++----- modules/ui.py | 1 + server.py | 13 ++-- 13 files changed, 194 insertions(+), 83 deletions(-) create mode 100644 modules/logging_colors.py diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index 58c4a0bb..df18df56 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -1,4 +1,5 @@ import inspect +import logging import re import sys from pathlib import Path @@ -71,7 +72,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc del layers - print('Loading model ...') if checkpoint.endswith('.safetensors'): from safetensors.torch import load_file as safe_load model.load_state_dict(safe_load(checkpoint), strict=False) @@ -90,8 +90,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc quant.autotune_warmup_fused(model) model.seqlen = 2048 - print('Done.') - return model @@ -119,11 +117,13 @@ def find_quantized_model_file(model_name): if len(found_pts) > 0: if len(found_pts) > 1: - print('Warning: more than one .pt model has been found. The last one will be selected. It could be wrong.') + logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.') + pt_path = found_pts[-1] elif len(found_safetensors) > 0: if len(found_pts) > 1: - print('Warning: more than one .safetensors model has been found. The last one will be selected. It could be wrong.') + logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.') + pt_path = found_safetensors[-1] return pt_path @@ -142,8 +142,7 @@ def load_quantized(model_name): elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])): model_type = 'gptj' else: - print("Can't determine model type from model name. Please specify it manually using --model_type " - "argument") + logging.error("Can't determine model type from model name. Please specify it manually using --model_type argument") exit() else: model_type = shared.args.model_type.lower() @@ -153,20 +152,21 @@ def load_quantized(model_name): load_quant = llama_inference_offload.load_quant elif model_type in ('llama', 'opt', 'gptj'): if shared.args.pre_layer: - print("Warning: ignoring --pre_layer because it only works for llama model type.") + logging.warning("Ignoring --pre_layer because it only works for llama model type.") + load_quant = _load_quant else: - print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported") + logging.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported") exit() # Find the quantized model weights file (.pt/.safetensors) path_to_model = Path(f'{shared.args.model_dir}/{model_name}') pt_path = find_quantized_model_file(model_name) if not pt_path: - print("Could not find the quantized model in .pt or .safetensors format, exiting...") + logging.error("Could not find the quantized model in .pt or .safetensors format, exiting...") exit() else: - print(f"Found the following quantized model: {pt_path}") + logging.info(f"Found the following quantized model: {pt_path}") # qwopqwop200's offload if model_type == 'llama' and shared.args.pre_layer: @@ -188,7 +188,7 @@ def load_quantized(model_name): max_memory = accelerate.utils.get_balanced_memory(model) device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"]) - print("Using the following device map for the quantized model:", device_map) + logging.info("Using the following device map for the quantized model:", device_map) # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True) diff --git a/modules/LoRA.py b/modules/LoRA.py index f734f3cd..f996289a 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path import torch @@ -18,7 +19,7 @@ def add_lora_to_model(lora_names): # Add a LoRA when another LoRA is already present if len(removed_set) == 0 and len(prior_set) > 0: - print(f"Adding the LoRA(s) named {added_set} to the model...") + logging.info(f"Adding the LoRA(s) named {added_set} to the model...") for lora in added_set: shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora) @@ -29,7 +30,7 @@ def add_lora_to_model(lora_names): shared.model.disable_adapter() if len(lora_names) > 0: - print("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names))) + logging.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names))) params = {} if not shared.args.cpu: params['dtype'] = shared.model.dtype diff --git a/modules/chat.py b/modules/chat.py index fb7c2dee..c17f4545 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -3,6 +3,7 @@ import base64 import copy import io import json +import logging import re from datetime import datetime from pathlib import Path @@ -138,7 +139,7 @@ def extract_message_from_reply(reply, state): def chatbot_wrapper(text, state, regenerate=False, _continue=False): if shared.model_name == 'None' or shared.model is None: - print("No model is loaded! Select one in the Model tab.") + logging.error("No model is loaded! Select one in the Model tab.") yield shared.history['visible'] return @@ -216,7 +217,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): def impersonate_wrapper(text, state): if shared.model_name == 'None' or shared.model is None: - print("No model is loaded! Select one in the Model tab.") + logging.error("No model is loaded! Select one in the Model tab.") yield '' return @@ -523,7 +524,7 @@ def upload_character(json_file, img, tavern=False): img = Image.open(io.BytesIO(img)) img.save(Path(f'characters/{outfile_name}.png')) - print(f'New character saved to "characters/{outfile_name}.json".') + logging.info(f'New character saved to "characters/{outfile_name}.json".') return outfile_name @@ -547,6 +548,6 @@ def upload_your_profile_picture(img, name1, name2, mode): else: img = make_thumbnail(img) img.save(Path('cache/pfp_me.png')) - print('Profile picture saved to "cache/pfp_me.png"') + logging.info('Profile picture saved to "cache/pfp_me.png"') return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) diff --git a/modules/extensions.py b/modules/extensions.py index 1a6a2d9b..1bb36d52 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,3 +1,4 @@ +import logging import traceback from functools import partial @@ -28,7 +29,7 @@ def load_extensions(): for i, name in enumerate(shared.args.extensions): if name in available_extensions: if name != 'api': - print(f'Loading the extension "{name}"... ', end='') + logging.info(f'Loading the extension "{name}"...') try: exec(f"import extensions.{name}.script") extension = getattr(extensions, name).script @@ -38,12 +39,8 @@ def load_extensions(): extension.setup() state[name] = [True, i] - if name != 'api': - print('Ok.') except: - if name != 'api': - print('Fail.') - + logging.error('Failed to load the extension "{name}".') traceback.print_exc() diff --git a/modules/llama_attn_hijack.py b/modules/llama_attn_hijack.py index f5c5c92e..e953f523 100644 --- a/modules/llama_attn_hijack.py +++ b/modules/llama_attn_hijack.py @@ -1,29 +1,28 @@ +import logging import math import sys +from typing import Optional, Tuple + import torch import torch.nn as nn import transformers.models.llama.modeling_llama -from typing import Optional -from typing import Tuple - import modules.shared as shared - if shared.args.xformers: try: import xformers.ops except Exception: - print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr) + logging.error("xformers not found! Please install it before trying to use it.", file=sys.stderr) def hijack_llama_attention(): if shared.args.xformers: transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward - print("Replaced attention with xformers_attention") + logging.info("Replaced attention with xformers_attention") elif shared.args.sdp_attention: transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward - print("Replaced attention with sdp_attention") + logging.info("Replaced attention with sdp_attention") def xformers_forward( @@ -55,16 +54,14 @@ def xformers_forward( past_key_value = (key_states, value_states) if use_cache else None - #We only apply xformers optimizations if we don't need to output the whole attention matrix + # We only apply xformers optimizations if we don't need to output the whole attention matrix if not output_attentions: - dtype = query_states.dtype - query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - - #This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. - #We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. + + # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. + # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: # input and output should be of form (bsz, q_len, num_heads, head_dim) attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None) @@ -102,9 +99,7 @@ def xformers_forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value @@ -137,7 +132,7 @@ def sdp_attention_forward( past_key_value = (key_states, value_states) if use_cache else None - #We only apply sdp attention if we don't need to output the whole attention matrix + # We only apply sdp attention if we don't need to output the whole attention matrix if not output_attentions: attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False) attn_weights = None diff --git a/modules/logging_colors.py b/modules/logging_colors.py new file mode 100644 index 00000000..b4fdf1be --- /dev/null +++ b/modules/logging_colors.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python +# encoding: utf-8 +import logging +# now we patch Python code to add color support to logging.StreamHandler + + +def add_coloring_to_emit_windows(fn): + # add methods we need to the class + def _out_handle(self): + import ctypes + return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE) + out_handle = property(_out_handle) + + def _set_color(self, code): + import ctypes + # Constants from the Windows API + self.STD_OUTPUT_HANDLE = -11 + hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE) + ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code) + + setattr(logging.StreamHandler, '_set_color', _set_color) + + def new(*args): + FOREGROUND_BLUE = 0x0001 # text color contains blue. + FOREGROUND_GREEN = 0x0002 # text color contains green. + FOREGROUND_RED = 0x0004 # text color contains red. + FOREGROUND_INTENSITY = 0x0008 # text color is intensified. + FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED + # winbase.h + # STD_INPUT_HANDLE = -10 + # STD_OUTPUT_HANDLE = -11 + # STD_ERROR_HANDLE = -12 + + # wincon.h + # FOREGROUND_BLACK = 0x0000 + FOREGROUND_BLUE = 0x0001 + FOREGROUND_GREEN = 0x0002 + # FOREGROUND_CYAN = 0x0003 + FOREGROUND_RED = 0x0004 + FOREGROUND_MAGENTA = 0x0005 + FOREGROUND_YELLOW = 0x0006 + # FOREGROUND_GREY = 0x0007 + FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified. + + # BACKGROUND_BLACK = 0x0000 + # BACKGROUND_BLUE = 0x0010 + # BACKGROUND_GREEN = 0x0020 + # BACKGROUND_CYAN = 0x0030 + # BACKGROUND_RED = 0x0040 + # BACKGROUND_MAGENTA = 0x0050 + BACKGROUND_YELLOW = 0x0060 + # BACKGROUND_GREY = 0x0070 + BACKGROUND_INTENSITY = 0x0080 # background color is intensified. + + levelno = args[1].levelno + if (levelno >= 50): + color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY + elif (levelno >= 40): + color = FOREGROUND_RED | FOREGROUND_INTENSITY + elif (levelno >= 30): + color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY + elif (levelno >= 20): + color = FOREGROUND_GREEN + elif (levelno >= 10): + color = FOREGROUND_MAGENTA + else: + color = FOREGROUND_WHITE + args[0]._set_color(color) + + ret = fn(*args) + args[0]._set_color(FOREGROUND_WHITE) + # print "after" + return ret + return new + + +def add_coloring_to_emit_ansi(fn): + # add methods we need to the class + def new(*args): + levelno = args[1].levelno + if (levelno >= 50): + color = '\x1b[31m' # red + elif (levelno >= 40): + color = '\x1b[31m' # red + elif (levelno >= 30): + color = '\x1b[33m' # yellow + elif (levelno >= 20): + color = '\x1b[32m' # green + elif (levelno >= 10): + color = '\x1b[35m' # pink + else: + color = '\x1b[0m' # normal + args[1].msg = color + args[1].msg + '\x1b[0m' # normal + # print "after" + return fn(*args) + return new + + +import platform +if platform.system() == 'Windows': + # Windows does not support ANSI escapes and we are using API calls to set the console color + logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit) +else: + # all non-Windows platforms are supporting ANSI escapes so we use them + logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit) + # log = logging.getLogger() + # log.addFilter(log_filter()) + # //hdlr = logging.StreamHandler() + # //hdlr.setFormatter(formatter()) diff --git a/modules/models.py b/modules/models.py index e2193e6c..6f4257ec 100644 --- a/modules/models.py +++ b/modules/models.py @@ -1,5 +1,6 @@ import gc import json +import logging import os import re import time @@ -65,7 +66,7 @@ def find_model_type(model_name): def load_model(model_name): - print(f"Loading {model_name}...") + logging.info(f"Loading {model_name}...") t0 = time.time() shared.model_type = find_model_type(model_name) @@ -116,7 +117,7 @@ def load_model(model_name): model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] model.module.eval() # Inference - print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") + logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") # RMKV model (not on HuggingFace) elif shared.model_type == 'rwkv': @@ -137,7 +138,7 @@ def load_model(model_name): else: model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0] - print(f"llama.cpp weights detected: {model_file}\n") + logging.info(f"llama.cpp weights detected: {model_file}\n") model, tokenizer = LlamaCppModel.from_pretrained(model_file) return model, tokenizer @@ -146,7 +147,7 @@ def load_model(model_name): # Monkey patch if shared.args.monkey_patch: - print("Warning: applying the monkey patch for using LoRAs in 4-bit mode.\nIt may cause undefined behavior outside its intended scope.") + logging.warning("Warning: applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.") from modules.monkey_patch_gptq_lora import load_model_llama model, _ = load_model_llama(model_name) @@ -161,7 +162,7 @@ def load_model(model_name): else: params = {"low_cpu_mem_usage": True} if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)): - print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n") + logging.warning("Warning: torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.") shared.args.cpu = True if shared.args.cpu: @@ -184,6 +185,7 @@ def load_model(model_name): max_memory = {} for i in range(len(memory_map)): max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i] + max_memory['cpu'] = max_cpu_memory params['max_memory'] = max_memory elif shared.args.auto_devices: @@ -191,9 +193,9 @@ def load_model(model_name): suggestion = round((total_mem - 1000) / 1000) * 1000 if total_mem - suggestion < 800: suggestion -= 1000 - suggestion = int(round(suggestion / 1000)) - print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m") + suggestion = int(round(suggestion / 1000)) + logging.warning(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m") max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'} params['max_memory'] = max_memory @@ -201,11 +203,11 @@ def load_model(model_name): params["offload_folder"] = shared.args.disk_cache_dir checkpoint = Path(f'{shared.args.model_dir}/{model_name}') - if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': config = AutoConfig.from_pretrained(checkpoint) with init_empty_weights(): model = LoaderClass.from_config(config) + model.tie_weights() params['device_map'] = infer_auto_device_map( model, @@ -230,7 +232,7 @@ def load_model(model_name): if shared.model_type != 'llava': for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]: if p.exists(): - print(f"Loading the universal LLaMA tokenizer from {p}...") + logging.info(f"Loading the universal LLaMA tokenizer from {p}...") tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True) break @@ -247,7 +249,7 @@ def load_model(model_name): else: tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code) - print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") + logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.") return model, tokenizer @@ -276,20 +278,20 @@ def load_soft_prompt(name): zf.extract('tensor.npy') zf.extract('meta.json') j = json.loads(open('meta.json', 'r').read()) - print(f"\nLoading the softprompt \"{name}\".") + logging.info(f"\nLoading the softprompt \"{name}\".") for field in j: if field != 'name': if type(j[field]) is list: - print(f"{field}: {', '.join(j[field])}") + logging.info(f"{field}: {', '.join(j[field])}") else: - print(f"{field}: {j[field]}") - print() + logging.info(f"{field}: {j[field]}") + logging.info() tensor = np.load('tensor.npy') Path('tensor.npy').unlink() Path('meta.json').unlink() + tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype) tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1])) - shared.soft_prompt = True shared.soft_prompt_tensor = tensor diff --git a/modules/monkey_patch_gptq_lora.py b/modules/monkey_patch_gptq_lora.py index 45318df6..a37e7906 100644 --- a/modules/monkey_patch_gptq_lora.py +++ b/modules/monkey_patch_gptq_lora.py @@ -17,6 +17,7 @@ from modules.GPTQ_loader import find_quantized_model_file replace_peft_model_with_gptq_lora_model() + def load_model_llama(model_name): config_path = str(Path(f'{shared.args.model_dir}/{model_name}')) model_path = str(find_quantized_model_file(model_name)) diff --git a/modules/shared.py b/modules/shared.py index ef96368f..8ce1ded2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,4 +1,5 @@ import argparse +import logging from pathlib import Path import yaml @@ -170,19 +171,19 @@ args_defaults = parser.parse_args([]) deprecated_dict = {} for k in deprecated_dict: if getattr(args, k) != deprecated_dict[k][1]: - print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.\n") + logging.warning(f"--{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.") setattr(args, deprecated_dict[k][0], getattr(args, k)) # Deprecation warnings for parameters that have been removed if args.cai_chat: - print("Warning: --cai-chat is deprecated. Use --chat instead.\n") + logging.warning("--cai-chat is deprecated. Use --chat instead.") args.chat = True # Security warnings if args.trust_remote_code: - print("Warning: trust_remote_code is enabled. This is dangerous.\n") + logging.warning("trust_remote_code is enabled. This is dangerous.") if args.share: - print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. This is potentially dangerous.\n") + logging.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.") # Activating the API extension if args.api or args.public_api: diff --git a/modules/text_generation.py b/modules/text_generation.py index 42a5c394..2a994a6a 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -1,4 +1,5 @@ import ast +import logging import random import re import time @@ -175,7 +176,7 @@ def get_generate_params(state): def generate_reply(question, state, eos_token=None, stopping_strings=[]): if shared.model_name == 'None' or shared.model is None: - print("No model is loaded! Select one in the Model tab.") + logging.error("No model is loaded! Select one in the Model tab.") yield formatted_outputs(question, shared.model_name) return diff --git a/modules/training.py b/modules/training.py index 9789c1c2..82e42f4d 100644 --- a/modules/training.py +++ b/modules/training.py @@ -1,4 +1,5 @@ import json +import logging import math import sys import threading @@ -40,7 +41,6 @@ WANT_INTERRUPT = False PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer"] - def get_datasets(path: str, ext: str): return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower) @@ -123,7 +123,7 @@ def create_train_interface(): stop_evaluation = gr.Button("Interrupt") with gr.Column(): - evaluation_log = gr.Markdown(value = '') + evaluation_log = gr.Markdown(value='') evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True) save_comments = gr.Button('Save comments') @@ -220,13 +220,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch if model_type == "PeftModelForCausalLM": if len(shared.args.lora_names) > 0: yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" - print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.") + logging.warning("Training LoRA over top of another LoRA. May have unexpected effects.") else: yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" - print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.") + logging.warning("Model ID not matched due to LoRA loading. Consider reloading base model.") else: yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" - print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})") + logging.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})") + time.sleep(5) if shared.args.wbits > 0 and not shared.args.monkey_patch: @@ -235,7 +236,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch elif not shared.args.load_in_8bit and shared.args.wbits <= 0: yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*" - print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.") + logging.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.") time.sleep(2) # Give it a moment for the message to show in UI before continuing if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: @@ -255,7 +256,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch # == Prep the dataset, format, etc == if raw_text_file not in ['None', '']: - print("Loading raw text file dataset...") + logging.info("Loading raw text file dataset...") with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: raw_text = file.read() @@ -299,7 +300,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch prompt = generate_prompt(data_point) return tokenize(prompt) - print("Loading JSON datasets...") + logging.info("Loading JSON datasets...") data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json')) train_data = data['train'].map(generate_and_tokenize_prompt) @@ -311,10 +312,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch # == Start prepping the model itself == if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): - print("Getting model ready...") + logging.info("Getting model ready...") prepare_model_for_int8_training(shared.model) - print("Prepping for training...") + logging.info("Prepping for training...") config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, @@ -325,10 +326,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch ) try: - print("Creating LoRA model...") + logging.info("Creating LoRA model...") lora_model = get_peft_model(shared.model, config) if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file(): - print("Loading existing LoRA data...") + logging.info("Loading existing LoRA data...") state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin") set_peft_model_state_dict(lora_model, state_dict_peft) except: @@ -406,7 +407,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch json.dump({x: vars[x] for x in PARAMETERS}, file) # == Main run and monitor loop == - print("Starting training...") + logging.info("Starting training...") yield "Starting..." if WANT_INTERRUPT: yield "Interrupted before start." @@ -416,7 +417,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch trainer.train() # Note: save in the thread in case the gradio thread breaks (eg browser closed) lora_model.save_pretrained(lora_file_path) - print("LoRA training run is completed and saved.") + logging.info("LoRA training run is completed and saved.") tracked.did_save = True thread = threading.Thread(target=threaded_run) @@ -448,14 +449,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch # Saving in the train thread might fail if an error occurs, so save here if so. if not tracked.did_save: - print("Training complete, saving...") + logging.info("Training complete, saving...") lora_model.save_pretrained(lora_file_path) if WANT_INTERRUPT: - print("Training interrupted.") + logging.info("Training interrupted.") yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`" else: - print("Training complete!") + logging.info("Training complete!") yield f"Done! LoRA saved to `{lora_file_path}`" diff --git a/modules/ui.py b/modules/ui.py index c40e596b..53f6a247 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -25,6 +25,7 @@ theme = gr.themes.Default( background_fill_secondary='#eaeaea' ) + def list_model_elements(): elements = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'wbits', 'groupsize', 'model_type', 'pre_layer'] for i in range(torch.cuda.device_count()): diff --git a/server.py b/server.py index f8d3b2bf..e940d366 100644 --- a/server.py +++ b/server.py @@ -1,14 +1,17 @@ +import logging import os import requests import warnings +import modules.logging_colors os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' os.environ['BITSANDBYTES_NOWELCOME'] = '1' warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') +logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO) # This is a hack to prevent Gradio from phoning home when it gets imported def my_get(url, **kwargs): - print('Gradio HTTP request redirected to localhost :)') + logging.info('Gradio HTTP request redirected to localhost :)') kwargs.setdefault('allow_redirects', True) return requests.api.request('get', 'http://127.0.0.1/', **kwargs) @@ -17,9 +20,8 @@ requests.get = my_get import gradio as gr requests.get = original_get -# This fixes LaTeX rendering on some systems import matplotlib -matplotlib.use('Agg') +matplotlib.use('Agg') # This fixes LaTeX rendering on some systems import importlib import io @@ -39,7 +41,6 @@ import psutil import torch import yaml from PIL import Image - import modules.extensions as extensions_module from modules import chat, shared, training, ui from modules.html_generator import chat_html_wrapper @@ -860,7 +861,7 @@ if __name__ == "__main__": elif Path('settings.json').exists(): settings_file = Path('settings.json') if settings_file is not None: - print(f"Loading settings from {settings_file}...") + logging.info(f"Loading settings from {settings_file}...") new_settings = json.loads(open(settings_file, 'r').read()) for item in new_settings: shared.settings[item] = new_settings[item] @@ -891,7 +892,7 @@ if __name__ == "__main__": # Select the model from a command-line menu elif shared.args.model_menu: if len(available_models) == 0: - print('No models are available! Please download at least one.') + logging.error('No models are available! Please download at least one.') sys.exit(0) else: print('The following models are available:\n')