mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add Feature to Log Sample of Training Dataset for Inspection (#1711)
This commit is contained in:
parent
b6ba68eda9
commit
73a0def4af
@ -579,8 +579,27 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
yield "Interrupted before start."
|
yield "Interrupted before start."
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def log_train_dataset(trainer):
|
||||||
|
decoded_entries = []
|
||||||
|
# Try to decode the entries and write the log file
|
||||||
|
try:
|
||||||
|
# Iterate over the first 10 elements in the dataset (or fewer if there are less than 10)
|
||||||
|
for i in range(min(10, len(trainer.train_dataset))):
|
||||||
|
decoded_text = shared.tokenizer.decode(trainer.train_dataset[i]['input_ids'])
|
||||||
|
decoded_entries.append({"value": decoded_text})
|
||||||
|
|
||||||
|
# Write the log file
|
||||||
|
Path('logs').mkdir(exist_ok=True)
|
||||||
|
with open(Path('logs/train_dataset_sample.json'), 'w') as json_file:
|
||||||
|
json.dump(decoded_entries, json_file, indent=4)
|
||||||
|
|
||||||
|
logger.info("Log file 'train_dataset_sample.json' created in the 'logs' directory.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create log file due to error: {e}")
|
||||||
|
|
||||||
def threaded_run():
|
def threaded_run():
|
||||||
|
log_train_dataset(trainer)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||||
lora_model.save_pretrained(lora_file_path)
|
lora_model.save_pretrained(lora_file_path)
|
||||||
|
Loading…
Reference in New Issue
Block a user