diff --git a/modules/training.py b/modules/training.py index 855ed914..7ead4b47 100644 --- a/modules/training.py +++ b/modules/training.py @@ -240,6 +240,21 @@ def backup_adapter(input_folder): except Exception as e: print("An error occurred in backup_adapter:", str(e)) +def calc_trainable_parameters(model): + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params,all_param + def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float): @@ -431,6 +446,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch if not always_override: backup_adapter(lora_file_path) + # == get model trainable params + model_trainable_params, model_all_params = calc_trainable_parameters(shared.model) + try: logger.info("Creating LoRA model...") lora_model = get_peft_model(shared.model, config) @@ -540,6 +558,12 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch logger.info("Starting training...") yield "Starting..." + lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model) + + if lora_all_param>0: + print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})") + + train_log.update({"base_model_name": shared.model_name}) train_log.update({"base_model_class": shared.model.__class__.__name__}) train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})