mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Added saving of training logs to training_log.json (#2769)
This commit is contained in:
parent
017884132f
commit
ce86f726e9
@ -11,7 +11,7 @@ import gradio as gr
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
|
||||
from peft import (LoraConfig, get_peft_model, prepare_model_for_kbit_training,
|
||||
set_peft_model_state_dict)
|
||||
|
||||
from modules import shared, ui, utils
|
||||
@ -38,6 +38,7 @@ except:
|
||||
"GPTNeoXForCausalLM": "gpt_neox"
|
||||
}
|
||||
|
||||
train_log = {}
|
||||
|
||||
WANT_INTERRUPT = False
|
||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after"]
|
||||
@ -357,7 +358,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
# == Start prepping the model itself ==
|
||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||
logger.info("Getting model ready...")
|
||||
prepare_model_for_int8_training(shared.model)
|
||||
prepare_model_for_kbit_training(shared.model)
|
||||
|
||||
logger.info("Prepping for training...")
|
||||
config = LoraConfig(
|
||||
@ -406,12 +407,19 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
control.should_training_stop = True
|
||||
elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
|
||||
lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
|
||||
# Save log
|
||||
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
|
||||
json.dump(train_log, file, indent=2)
|
||||
|
||||
|
||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
tracked.current_steps += 1
|
||||
if WANT_INTERRUPT:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
|
||||
train_log.update(logs)
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=lora_model,
|
||||
@ -448,7 +456,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
# == Save parameters for reuse ==
|
||||
with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
|
||||
vars = locals()
|
||||
json.dump({x: vars[x] for x in PARAMETERS}, file)
|
||||
json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
|
||||
|
||||
# == Main run and monitor loop ==
|
||||
logger.info("Starting training...")
|
||||
@ -462,7 +470,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||
lora_model.save_pretrained(lora_file_path)
|
||||
logger.info("LoRA training run is completed and saved.")
|
||||
tracked.did_save = True
|
||||
# Save log
|
||||
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
|
||||
json.dump(train_log, file, indent=2)
|
||||
|
||||
thread = threading.Thread(target=threaded_run)
|
||||
thread.start()
|
||||
|
Loading…
Reference in New Issue
Block a user