mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Training update - backup the existing adapter before training on top of it (#2902)
This commit is contained in:
parent
40bbd53640
commit
ab1998146b
@ -10,6 +10,10 @@ from pathlib import Path
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from datasets import Dataset, load_dataset
|
from datasets import Dataset, load_dataset
|
||||||
from peft import (
|
from peft import (
|
||||||
LoraConfig,
|
LoraConfig,
|
||||||
@ -208,6 +212,35 @@ def clean_path(base_path: str, path: str):
|
|||||||
return f'{Path(base_path).absolute()}/{path}'
|
return f'{Path(base_path).absolute()}/{path}'
|
||||||
|
|
||||||
|
|
||||||
|
def backup_adapter(input_folder: str) -> None:
    """Copy an existing LoRA adapter into a dated backup subfolder.

    If ``input_folder`` contains ``adapter_model.bin``, every regular file
    located directly inside ``input_folder`` is copied (with metadata, via
    ``shutil.copy2``) into ``<input_folder>/Backup-YYYY-MM-DD``.
    Subdirectories, including earlier backup folders, are not copied.

    The date in the folder name is the last-modification date of
    ``adapter_model.bin``. If that folder already holds an
    ``adapter_model.bin``, the backup is skipped.

    The operation is best-effort: any exception is logged and swallowed, so
    nothing is ever raised to the caller.

    Args:
        input_folder: Path of the folder that holds the LoRA adapter files.
    """
    try:
        lora_dir = Path(input_folder)
        adapter_file = lora_dir / "adapter_model.bin"
        if not adapter_file.is_file():
            # Nothing has been trained here yet, so there is nothing to back up.
            return

        logger.info("Backing up existing LoRA adapter...")

        # Use the modification time rather than st_ctime. st_ctime is the
        # inode-change time on Unix and the creation time on Windows, where it
        # survives overwriting the file; every later version of the adapter
        # would then map to the same backup folder and be skipped.
        modified_at = datetime.fromtimestamp(adapter_file.stat().st_mtime)
        backup_dir = lora_dir / modified_at.strftime("Backup-%Y-%m-%d")
        backup_dir.mkdir(parents=True, exist_ok=True)

        # A backup folder that already contains the adapter is treated as done.
        if (backup_dir / "adapter_model.bin").is_file():
            logger.info(" - Backup already exists. Skipping backup process.")
            return

        # Only top-level regular files are copied; directories are left alone.
        for entry in lora_dir.iterdir():
            if entry.is_file():
                shutil.copy2(entry, backup_dir)
    except Exception as e:
        # Deliberately broad: a failed backup is reported but never propagated.
        logger.error("An error occurred in backup_adapter: %s", e)
|
||||||
|
|
||||||
|
|
||||||
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
|
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
|
||||||
|
|
||||||
if shared.args.monkey_patch:
|
if shared.args.monkey_patch:
|
||||||
@ -394,6 +427,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
task_type="CAUSAL_LM"
|
task_type="CAUSAL_LM"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# == Backup the existing adapter ==
|
||||||
|
if not always_override:
|
||||||
|
backup_adapter(lora_file_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("Creating LoRA model...")
|
logger.info("Creating LoRA model...")
|
||||||
lora_model = get_peft_model(shared.model, config)
|
lora_model = get_peft_model(shared.model, config)
|
||||||
|
Loading…
Reference in New Issue
Block a user