oobabooga/text-generation-webui (mirror of https://github.com/oobabooga/text-generation-webui.git)
Added saving of training logs to training_log.json (#2769)
parent 017884132f
commit ce86f726e9
@@ -11,7 +11,7 @@ import gradio as gr
 import torch
 import transformers
 from datasets import Dataset, load_dataset
-from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
+from peft import (LoraConfig, get_peft_model, prepare_model_for_kbit_training,
                   set_peft_model_state_dict)
 
 from modules import shared, ui, utils
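The import swap tracks PEFT's rename: prepare_model_for_int8_training was deprecated in favor of prepare_model_for_kbit_training, which covers both 8-bit and 4-bit quantized models. As a rough illustration (not part of this commit), a loader that has to run against both old and new PEFT releases could fall back like this:

# Illustrative compatibility shim, not from the repo: prefer the newer helper,
# fall back to the older int8 variant on old PEFT releases.
try:
    from peft import prepare_model_for_kbit_training
except ImportError:
    from peft import prepare_model_for_int8_training as prepare_model_for_kbit_training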
@@ -38,6 +38,7 @@ except:
         "GPTNeoXForCausalLM": "gpt_neox"
     }
 
+train_log = {}
 
 WANT_INTERRUPT = False
 PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after"]
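train_log is a module-level dict that the new on_log callback further down in this diff fills with whatever metrics the Trainer reports, so the latest values are available whenever training_log.json is written. A minimal standalone sketch of the accumulate-then-dump pattern (the record helper and the sample values are made up for illustration):

# Illustrative only: accumulate the most recent Trainer metrics in a dict, dump as JSON.
import json

train_log = {}

def record(logs: dict):
    # Later entries overwrite earlier keys, so the dict holds the latest value per metric.
    train_log.update(logs)

record({"loss": 2.31, "learning_rate": 3e-4})
record({"loss": 1.87, "epoch": 0.5})

with open("training_log.json", "w", encoding="utf-8") as file:
    json.dump(train_log, file, indent=2)
    # file now contains {"loss": 1.87, "learning_rate": 0.0003, "epoch": 0.5}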
@@ -357,7 +358,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
         logger.info("Getting model ready...")
-        prepare_model_for_int8_training(shared.model)
+        prepare_model_for_kbit_training(shared.model)
 
     logger.info("Prepping for training...")
     config = LoraConfig(
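prepare_model_for_kbit_training does the same stabilization work the int8 helper did (casting norm layers to fp32, enabling input gradients for gradient checkpointing) while also handling 4-bit loads. A hedged sketch of where it typically sits relative to get_peft_model, assuming bitsandbytes is installed; the model name and LoRA settings are placeholders, not values from this repo:

# Illustrative ordering only: quantized load -> kbit prep -> LoRA wrap.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # placeholder model
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
model = prepare_model_for_kbit_training(model)  # stabilize quantized weights for training
lora_model = get_peft_model(
    model,
    LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"],
               lora_dropout=0.05, task_type="CAUSAL_LM"),
)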
@@ -406,6 +407,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                 control.should_training_stop = True
             elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
                 lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
+                # Save log
+                with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
+                    json.dump(train_log, file, indent=2)
+
 
         def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
            tracked.current_steps += 1
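With this hunk, every periodic checkpoint gets its own training_log.json next to the adapter weights, so runs can be compared checkpoint by checkpoint. A minimal standalone sketch of that step; the directory layout mirrors the f-strings above, but the helper name is invented:

# Illustrative helper, not from the repo: write the accumulated log next to a checkpoint.
import json
import os

def save_checkpoint_log(lora_file_path: str, step: int, train_log: dict) -> None:
    checkpoint_dir = f"{lora_file_path}/checkpoint-{step}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(f"{checkpoint_dir}/training_log.json", "w", encoding="utf-8") as file:
        json.dump(train_log, file, indent=2)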
@@ -413,6 +418,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                 control.should_epoch_stop = True
                 control.should_training_stop = True
 
+        def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
+            train_log.update(logs)
+
     trainer = transformers.Trainer(
         model=lora_model,
         train_dataset=train_data,
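on_log is a standard hook on transformers.TrainerCallback: the Trainer calls it with the same dict it would otherwise print (loss, learning_rate, epoch, ...), so copying it into train_log captures the latest metrics without touching the training loop. A self-contained sketch of the same idea, assuming a generic Trainer setup rather than this repo's callback class:

# Illustrative only: a minimal TrainerCallback that mirrors Trainer logs into a dict.
import transformers

train_log = {}

class LogCapture(transformers.TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            train_log.update(logs)  # keep the most recent value for each metric

# Usage sketch: trainer = transformers.Trainer(..., callbacks=[LogCapture()])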
@@ -448,7 +456,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     # == Save parameters for reuse ==
     with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
         vars = locals()
-        json.dump({x: vars[x] for x in PARAMETERS}, file)
+        json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
 
     # == Main run and monitor loop ==
     logger.info("Starting training...")
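The only change in this hunk is adding indent=2, so training_parameters.json is pretty-printed instead of written on a single line; the parsed content is identical. For example (placeholder values):

# Same data, different formatting: indent only affects the serialized layout.
import json

params = {"lora_name": "my-lora", "epochs": 3, "learning_rate": 3e-4}  # placeholder values
print(json.dumps(params))            # {"lora_name": "my-lora", "epochs": 3, "learning_rate": 0.0003}
print(json.dumps(params, indent=2))  # same keys, one per line, indented two spaces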
@@ -462,7 +470,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         # Note: save in the thread in case the gradio thread breaks (eg browser closed)
         lora_model.save_pretrained(lora_file_path)
         logger.info("LoRA training run is completed and saved.")
-        tracked.did_save = True
+        # Save log
+        with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
+            json.dump(train_log, file, indent=2)
 
     thread = threading.Thread(target=threaded_run)
     thread.start()
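After training finishes, the accumulated train_log is written once more at the top level of the LoRA folder, still inside threaded_run so the file lands even if the gradio thread has gone away. A small illustrative example of reading that file back after a run; the path is a placeholder, not a path guaranteed by the repo:

# Illustrative only: inspect the final training_log.json written by a finished run.
import json

with open("loras/my-lora/training_log.json", encoding="utf-8") as file:  # placeholder path
    log = json.load(file)

print(log.get("loss"), log.get("epoch"))  # latest values reported by the Trainer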