Merge remote-tracking branch 'refs/remotes/origin/main'

oobabooga 2023-06-27 18:48:17 -03:00
commit c95009d2bd


@@ -10,6 +10,10 @@ from pathlib import Path
import gradio as gr
import torch
import transformers
import shutil
from datetime import datetime
from datasets import Dataset, load_dataset
from peft import (
    LoraConfig,
@@ -208,6 +212,35 @@ def clean_path(base_path: str, path: str):
    return f'{Path(base_path).absolute()}/{path}'

def backup_adapter(input_folder):
    # Get the creation date of the file adapter_model.bin
    try:
        adapter_file = Path(f"{input_folder}/adapter_model.bin")
        if adapter_file.is_file():
            logger.info("Backing up existing LoRA adapter...")
            creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
            creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")

            # Create the new subfolder
            subfolder_path = Path(f"{input_folder}/{creation_date_str}")
            subfolder_path.mkdir(parents=True, exist_ok=True)

            # Check if the file already exists in the subfolder
            backup_adapter_file = Path(f"{input_folder}/{creation_date_str}/adapter_model.bin")
            if backup_adapter_file.is_file():
                print(" - Backup already exists. Skipping backup process.")
                return

            # Copy existing files to the new subfolder
            existing_files = Path(input_folder).iterdir()
            for file in existing_files:
                if file.is_file():
                    shutil.copy2(file, subfolder_path)
    except Exception as e:
        print("An error occurred in backup_adapter:", str(e))

def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
    if shared.args.monkey_patch:
@@ -394,6 +427,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
task_type="CAUSAL_LM" task_type="CAUSAL_LM"
) )
# == Backup the existing adapter ==
if not always_override:
backup_adapter(lora_file_path)
try: try:
logger.info("Creating LoRA model...") logger.info("Creating LoRA model...")
lora_model = get_peft_model(shared.model, config) lora_model = get_peft_model(shared.model, config)
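
For context, a minimal standalone sketch of what the new backup_adapter helper does: before training, every file in the LoRA output folder is copied into a Backup-YYYY-MM-DD subfolder (dated by the creation time of adapter_model.bin), and the backup is skipped if a copy for that date already exists. The temporary folder, file names, and the modules.training import path below are assumptions for illustration, not part of this commit.

# Hypothetical check of the backup behaviour (assumes the webui's
# modules package is importable from the repository root).
import tempfile
from datetime import datetime
from pathlib import Path

from modules.training import backup_adapter  # assumed import location

with tempfile.TemporaryDirectory() as tmp:
    lora_dir = Path(tmp) / "my-lora"          # made-up LoRA folder
    lora_dir.mkdir()
    (lora_dir / "adapter_model.bin").write_bytes(b"\x00")
    (lora_dir / "adapter_config.json").write_text("{}")

    backup_adapter(str(lora_dir))

    # A freshly created adapter_model.bin has today's creation date,
    # so the backup lands in a Backup-<today> subfolder.
    backup_dir = lora_dir / datetime.now().strftime("Backup-%Y-%m-%d")
    print(sorted(p.name for p in backup_dir.iterdir()))
    # Expected: ['adapter_config.json', 'adapter_model.bin']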