Added saving of training logs to training_log.json (#2769)

This commit is contained in:
FartyPants 2023-06-19 23:47:36 -04:00 committed by GitHub
parent 017884132f
commit ce86f726e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -11,7 +11,7 @@ import gradio as gr
import torch
import transformers
from datasets import Dataset, load_dataset
from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
from peft import (LoraConfig, get_peft_model, prepare_model_for_kbit_training,
set_peft_model_state_dict)
from modules import shared, ui, utils
@ -38,6 +38,7 @@ except:
"GPTNeoXForCausalLM": "gpt_neox"
}
train_log = {}
WANT_INTERRUPT = False
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after"]
@ -357,7 +358,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
logger.info("Getting model ready...")
prepare_model_for_int8_training(shared.model)
prepare_model_for_kbit_training(shared.model)
logger.info("Prepping for training...")
config = LoraConfig(
@ -406,12 +407,19 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
control.should_training_stop = True
elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
# Save log
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
json.dump(train_log, file, indent=2)
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
tracked.current_steps += 1
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
train_log.update(logs)
trainer = transformers.Trainer(
model=lora_model,
@ -448,7 +456,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Save parameters for reuse ==
with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
vars = locals()
json.dump({x: vars[x] for x in PARAMETERS}, file)
json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
# == Main run and monitor loop ==
logger.info("Starting training...")
@ -462,7 +470,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
lora_model.save_pretrained(lora_file_path)
logger.info("LoRA training run is completed and saved.")
tracked.did_save = True
# Save log
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
json.dump(train_log, file, indent=2)
thread = threading.Thread(target=threaded_run)
thread.start()