mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Improve several log messages
This commit is contained in:
parent
23818dc098
commit
9992f7d8c0
@ -126,7 +126,7 @@ def load_quantized(model_name):
|
|||||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||||
pt_path = find_quantized_model_file(model_name)
|
pt_path = find_quantized_model_file(model_name)
|
||||||
if not pt_path:
|
if not pt_path:
|
||||||
logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
|
||||||
exit()
|
exit()
|
||||||
else:
|
else:
|
||||||
logger.info(f"Found the following quantized model: {pt_path}")
|
logger.info(f"Found the following quantized model: {pt_path}")
|
||||||
|
@ -138,7 +138,7 @@ def add_lora_transformers(lora_names):
|
|||||||
|
|
||||||
# Add a LoRA when another LoRA is already present
|
# Add a LoRA when another LoRA is already present
|
||||||
if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
|
if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
|
||||||
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
logger.info(f"Adding the LoRA(s) named {added_set} to the model")
|
||||||
for lora in added_set:
|
for lora in added_set:
|
||||||
shared.model.load_adapter(get_lora_path(lora), lora)
|
shared.model.load_adapter(get_lora_path(lora), lora)
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ def load_extensions():
|
|||||||
for i, name in enumerate(shared.args.extensions):
|
for i, name in enumerate(shared.args.extensions):
|
||||||
if name in available_extensions:
|
if name in available_extensions:
|
||||||
if name != 'api':
|
if name != 'api':
|
||||||
logger.info(f'Loading the extension "{name}"...')
|
logger.info(f'Loading the extension "{name}"')
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
exec(f"import extensions.{name}.script")
|
exec(f"import extensions.{name}.script")
|
||||||
|
@ -54,7 +54,7 @@ sampler_hijack.hijack_samplers()
|
|||||||
|
|
||||||
|
|
||||||
def load_model(model_name, loader=None):
|
def load_model(model_name, loader=None):
|
||||||
logger.info(f"Loading {model_name}...")
|
logger.info(f"Loading {model_name}")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
shared.is_seq2seq = False
|
shared.is_seq2seq = False
|
||||||
|
@ -204,22 +204,26 @@ for arg in sys.argv[1:]:
|
|||||||
if hasattr(args, arg):
|
if hasattr(args, arg):
|
||||||
provided_arguments.append(arg)
|
provided_arguments.append(arg)
|
||||||
|
|
||||||
# Deprecation warnings
|
|
||||||
deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
|
deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
|
||||||
for k in deprecated_args:
|
|
||||||
if getattr(args, k):
|
|
||||||
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
|
|
||||||
|
|
||||||
# Security warnings
|
|
||||||
if args.trust_remote_code:
|
def do_cmd_flags_warnings():
|
||||||
logger.warning('trust_remote_code is enabled. This is dangerous.')
|
|
||||||
if 'COLAB_GPU' not in os.environ and not args.nowebui:
|
# Deprecation warnings
|
||||||
if args.share:
|
for k in deprecated_args:
|
||||||
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
|
if getattr(args, k):
|
||||||
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
|
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
|
||||||
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
|
|
||||||
if args.multi_user:
|
# Security warnings
|
||||||
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
|
if args.trust_remote_code:
|
||||||
|
logger.warning('trust_remote_code is enabled. This is dangerous.')
|
||||||
|
if 'COLAB_GPU' not in os.environ and not args.nowebui:
|
||||||
|
if args.share:
|
||||||
|
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
|
||||||
|
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
|
||||||
|
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
|
||||||
|
if args.multi_user:
|
||||||
|
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
|
||||||
|
|
||||||
|
|
||||||
def fix_loader_name(name):
|
def fix_loader_name(name):
|
||||||
|
@ -249,7 +249,7 @@ def backup_adapter(input_folder):
|
|||||||
adapter_file = Path(f"{input_folder}/adapter_model.bin")
|
adapter_file = Path(f"{input_folder}/adapter_model.bin")
|
||||||
if adapter_file.is_file():
|
if adapter_file.is_file():
|
||||||
|
|
||||||
logger.info("Backing up existing LoRA adapter...")
|
logger.info("Backing up existing LoRA adapter")
|
||||||
creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
|
creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
|
||||||
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
|
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
|
||||||
|
|
||||||
@ -406,7 +406,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
# == Prep the dataset, format, etc ==
|
# == Prep the dataset, format, etc ==
|
||||||
if raw_text_file not in ['None', '']:
|
if raw_text_file not in ['None', '']:
|
||||||
train_template["template_type"] = "raw_text"
|
train_template["template_type"] = "raw_text"
|
||||||
logger.info("Loading raw text file dataset...")
|
logger.info("Loading raw text file dataset")
|
||||||
fullpath = clean_path('training/datasets', f'{raw_text_file}')
|
fullpath = clean_path('training/datasets', f'{raw_text_file}')
|
||||||
fullpath = Path(fullpath)
|
fullpath = Path(fullpath)
|
||||||
if fullpath.is_dir():
|
if fullpath.is_dir():
|
||||||
@ -486,7 +486,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
prompt = generate_prompt(data_point)
|
prompt = generate_prompt(data_point)
|
||||||
return tokenize(prompt, add_eos_token)
|
return tokenize(prompt, add_eos_token)
|
||||||
|
|
||||||
logger.info("Loading JSON datasets...")
|
logger.info("Loading JSON datasets")
|
||||||
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
||||||
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||||
|
|
||||||
@ -516,13 +516,13 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
|
|
||||||
# == Start prepping the model itself ==
|
# == Start prepping the model itself ==
|
||||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||||
logger.info("Getting model ready...")
|
logger.info("Getting model ready")
|
||||||
prepare_model_for_kbit_training(shared.model)
|
prepare_model_for_kbit_training(shared.model)
|
||||||
|
|
||||||
# base model is now frozen and should not be reused for any other LoRA training than this one
|
# base model is now frozen and should not be reused for any other LoRA training than this one
|
||||||
shared.model_dirty_from_training = True
|
shared.model_dirty_from_training = True
|
||||||
|
|
||||||
logger.info("Preparing for training...")
|
logger.info("Preparing for training")
|
||||||
config = LoraConfig(
|
config = LoraConfig(
|
||||||
r=lora_rank,
|
r=lora_rank,
|
||||||
lora_alpha=lora_alpha,
|
lora_alpha=lora_alpha,
|
||||||
@ -540,10 +540,10 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
|
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("Creating LoRA model...")
|
logger.info("Creating LoRA model")
|
||||||
lora_model = get_peft_model(shared.model, config)
|
lora_model = get_peft_model(shared.model, config)
|
||||||
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
||||||
logger.info("Loading existing LoRA data...")
|
logger.info("Loading existing LoRA data")
|
||||||
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
|
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
|
||||||
set_peft_model_state_dict(lora_model, state_dict_peft)
|
set_peft_model_state_dict(lora_model, state_dict_peft)
|
||||||
except:
|
except:
|
||||||
@ -648,7 +648,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
json.dump(train_template, file, indent=2)
|
json.dump(train_template, file, indent=2)
|
||||||
|
|
||||||
# == Main run and monitor loop ==
|
# == Main run and monitor loop ==
|
||||||
logger.info("Starting training...")
|
logger.info("Starting training")
|
||||||
yield "Starting..."
|
yield "Starting..."
|
||||||
|
|
||||||
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
||||||
@ -730,7 +730,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
|||||||
|
|
||||||
# Saving in the train thread might fail if an error occurs, so save here if so.
|
# Saving in the train thread might fail if an error occurs, so save here if so.
|
||||||
if not tracked.did_save:
|
if not tracked.did_save:
|
||||||
logger.info("Training complete, saving...")
|
logger.info("Training complete, saving")
|
||||||
lora_model.save_pretrained(lora_file_path)
|
lora_model.save_pretrained(lora_file_path)
|
||||||
|
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
|
@ -12,6 +12,7 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
|||||||
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
||||||
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
|
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
|
||||||
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
|
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
|
||||||
|
|
||||||
with RequestBlocker():
|
with RequestBlocker():
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -54,6 +55,7 @@ from modules.models_settings import (
|
|||||||
get_model_metadata,
|
get_model_metadata,
|
||||||
update_model_parameters
|
update_model_parameters
|
||||||
)
|
)
|
||||||
|
from modules.shared import do_cmd_flags_warnings
|
||||||
from modules.utils import gradio
|
from modules.utils import gradio
|
||||||
|
|
||||||
|
|
||||||
@ -170,6 +172,9 @@ def create_interface():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
logger.info("Starting Text generation web UI")
|
||||||
|
do_cmd_flags_warnings()
|
||||||
|
|
||||||
# Load custom settings
|
# Load custom settings
|
||||||
settings_file = None
|
settings_file = None
|
||||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||||
@ -180,7 +185,7 @@ if __name__ == "__main__":
|
|||||||
settings_file = Path('settings.json')
|
settings_file = Path('settings.json')
|
||||||
|
|
||||||
if settings_file is not None:
|
if settings_file is not None:
|
||||||
logger.info(f"Loading settings from {settings_file}...")
|
logger.info(f"Loading settings from {settings_file}")
|
||||||
file_contents = open(settings_file, 'r', encoding='utf-8').read()
|
file_contents = open(settings_file, 'r', encoding='utf-8').read()
|
||||||
new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents)
|
new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents)
|
||||||
shared.settings.update(new_settings)
|
shared.settings.update(new_settings)
|
||||||
|
Loading…
Reference in New Issue
Block a user