mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add Feature to Log Sample of Training Dataset for Inspection (#1711)
This commit is contained in:
parent
b6ba68eda9
commit
73a0def4af
@ -580,7 +580,26 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
yield "Interrupted before start."
|
||||
return
|
||||
|
||||
def log_train_dataset(trainer):
|
||||
decoded_entries = []
|
||||
# Try to decode the entries and write the log file
|
||||
try:
|
||||
# Iterate over the first 10 elements in the dataset (or fewer if there are less than 10)
|
||||
for i in range(min(10, len(trainer.train_dataset))):
|
||||
decoded_text = shared.tokenizer.decode(trainer.train_dataset[i]['input_ids'])
|
||||
decoded_entries.append({"value": decoded_text})
|
||||
|
||||
# Write the log file
|
||||
Path('logs').mkdir(exist_ok=True)
|
||||
with open(Path('logs/train_dataset_sample.json'), 'w') as json_file:
|
||||
json.dump(decoded_entries, json_file, indent=4)
|
||||
|
||||
logger.info("Log file 'train_dataset_sample.json' created in the 'logs' directory.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create log file due to error: {e}")
|
||||
|
||||
def threaded_run():
|
||||
log_train_dataset(trainer)
|
||||
trainer.train()
|
||||
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||
lora_model.save_pretrained(lora_file_path)
|
||||
|
Loading…
Reference in New Issue
Block a user