diff --git a/modules/training.py b/modules/training.py
index 75ba82ca..d039e807 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -11,7 +11,7 @@ import gradio as gr
 import torch
 import transformers
 from datasets import Dataset, load_dataset
-from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
+from peft import (LoraConfig, get_peft_model, prepare_model_for_kbit_training,
                   set_peft_model_state_dict)
 
 from modules import shared, ui, utils
@@ -38,6 +38,7 @@ except:
     "GPTNeoXForCausalLM": "gpt_neox"
 }
 
+train_log = {}
 WANT_INTERRUPT = False
 PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after"]
 
@@ -357,7 +358,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
         logger.info("Getting model ready...")
-        prepare_model_for_int8_training(shared.model)
+        prepare_model_for_kbit_training(shared.model)
 
     logger.info("Prepping for training...")
     config = LoraConfig(
@@ -406,12 +407,19 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                 control.should_training_stop = True
             elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
                 lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
+                # Save log
+                with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
+                    json.dump(train_log, file, indent=2)
+
 
         def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
             tracked.current_steps += 1
             if WANT_INTERRUPT:
                 control.should_epoch_stop = True
                 control.should_training_stop = True
 
+        def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
+            train_log.update(logs)
+
     trainer = transformers.Trainer(
         model=lora_model,
@@ -448,7 +456,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     # == Save parameters for reuse ==
     with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
         vars = locals()
-        json.dump({x: vars[x] for x in PARAMETERS}, file)
+        json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
 
     # == Main run and monitor loop ==
     logger.info("Starting training...")
@@ -462,7 +470,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         # Note: save in the thread in case the gradio thread breaks (eg browser closed)
         lora_model.save_pretrained(lora_file_path)
         logger.info("LoRA training run is completed and saved.")
-        tracked.did_save = True
+        # Save log
+        with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
+            json.dump(train_log, file, indent=2)
 
     thread = threading.Thread(target=threaded_run)
     thread.start()
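
For context, the core of this change is the `on_log` callback: `transformers.Trainer` reports a metrics dict (loss, learning rate, epoch, ...) at each logging step, and the new hook folds every report into the module-level `train_log` dict, which is then dumped as `training_log.json` both at each checkpoint and at the end of the run. (The `prepare_model_for_int8_training` -> `prepare_model_for_kbit_training` rename tracks newer peft releases, where the int8 helper was superseded by a k-bit variant covering 4-bit as well as 8-bit quantized models.) Below is a minimal self-contained sketch of the logging pattern; the names `LogCallback` and `save_log` are illustrative placeholders, not part of this patch.

```python
import json

import transformers

# Module-level accumulator, mirroring the train_log dict this patch adds.
train_log = {}


class LogCallback(transformers.TrainerCallback):
    """Standalone sketch of the on_log hook added in this diff."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # 'logs' is the metrics dict the Trainer just reported, e.g.
        # {'loss': 1.93, 'learning_rate': 2.9e-05, 'epoch': 0.25}.
        # dict.update() keeps only the latest value per key, so the saved
        # file reflects the most recent metrics at dump time.
        if logs is not None:
            train_log.update(logs)


def save_log(path):
    # The same dump the patch performs at checkpoints and after training.
    with open(path, 'w', encoding='utf-8') as file:
        json.dump(train_log, file, indent=2)
```

Because `train_log.update(logs)` overwrites keys in place, the checkpoint-time dumps give a snapshot of metrics as of that step, while the final dump records the last values the Trainer logged.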