From ab1998146b3ba5045caf1cc520c81524835cb9b6 Mon Sep 17 00:00:00 2001
From: FartyPants
Date: Tue, 27 Jun 2023 17:24:04 -0400
Subject: [PATCH] Training update - backup the existing adapter before
 training on top of it (#2902)

---
 modules/training.py | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/modules/training.py b/modules/training.py
index c98cfc3b..855ed914 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -10,6 +10,10 @@ from pathlib import Path
 import gradio as gr
 import torch
 import transformers
+
+import shutil
+from datetime import datetime
+
 from datasets import Dataset, load_dataset
 from peft import (
     LoraConfig,
@@ -208,6 +212,35 @@ def clean_path(base_path: str, path: str):
     return f'{Path(base_path).absolute()}/{path}'
 
 
+def backup_adapter(input_folder):
+    # Get the creation date of the file adapter_model.bin
+    try:
+        adapter_file = Path(f"{input_folder}/adapter_model.bin")
+        if adapter_file.is_file():
+
+            logger.info("Backing up existing LoRA adapter...")
+            creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
+            creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
+
+            # Create the new subfolder
+            subfolder_path = Path(f"{input_folder}/{creation_date_str}")
+            subfolder_path.mkdir(parents=True, exist_ok=True)
+
+            # Check if the file already exists in the subfolder
+            backup_adapter_file = Path(f"{input_folder}/{creation_date_str}/adapter_model.bin")
+            if backup_adapter_file.is_file():
+                print(" - Backup already exists. Skipping backup process.")
+                return
+
+            # Copy existing files to the new subfolder
+            existing_files = Path(input_folder).iterdir()
+            for file in existing_files:
+                if file.is_file():
+                    shutil.copy2(file, subfolder_path)
+    except Exception as e:
+        print("An error occurred in backup_adapter:", str(e))
+
+
 def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
 
     if shared.args.monkey_patch:
@@ -394,6 +427,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         task_type="CAUSAL_LM"
     )
 
+    # == Backup the existing adapter ==
+    if not always_override:
+        backup_adapter(lora_file_path)
+
     try:
         logger.info("Creating LoRA model...")
         lora_model = get_peft_model(shared.model, config)
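
Note (not part of the patch): a minimal standalone sketch of what the new
backup_adapter() helper effectively does before do_train() overwrites an
existing adapter. The "loras/my-lora" path is hypothetical; the real helper
also logs through the project's logger and wraps everything in try/except,
which this sketch omits.

    from datetime import datetime
    from pathlib import Path
    import shutil

    lora_dir = Path("loras/my-lora")          # hypothetical LoRA output folder
    adapter = lora_dir / "adapter_model.bin"
    if adapter.is_file():
        # The backup subfolder is named after the adapter file's creation/change
        # date, e.g. "Backup-2023-06-27".
        stamp = datetime.fromtimestamp(adapter.stat().st_ctime).strftime("Backup-%Y-%m-%d")
        backup_dir = lora_dir / stamp
        backup_dir.mkdir(parents=True, exist_ok=True)
        # If that folder already holds an adapter, the backup is skipped.
        if not (backup_dir / "adapter_model.bin").is_file():
            # Copy only top-level files (adapter weights, config, etc.);
            # subfolders such as older backups are left alone, and copy2
            # preserves file timestamps.
            for f in lora_dir.iterdir():
                if f.is_file():
                    shutil.copy2(f, backup_dir)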