Some QoL changes to "Perplexity evaluation"

oobabooga 2023-05-25 15:06:22 -03:00
parent 8efdc01ffb
commit acfd876f29
2 changed files with 6 additions and 3 deletions

File 1 of 2

@@ -78,7 +78,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
             yield cumulative_log
             continue
-        cumulative_log += f"Processing {model}...\n\n"
+        cumulative_log += f"Processing {shared.model_name}...\n\n"
         yield cumulative_log + "Tokenizing the input dataset...\n\n"
         encodings = encode(text, add_special_tokens=False)
         seq_len = encodings.shape[1]
@@ -110,7 +110,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
         ppl = torch.exp(torch.stack(nlls).mean())
         add_entry_to_past_evaluations(float(ppl), shared.model_name, input_dataset, stride, _max_length)
         save_past_evaluations(past_evaluations)
-        cumulative_log += f"Done. The perplexity is: {float(ppl)}\n\n"
+        cumulative_log += f"The perplexity for {shared.model_name} is: {float(ppl)}\n\n"
         yield cumulative_log
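For context, the nlls list and the torch.exp(torch.stack(nlls).mean()) aggregation seen in the second hunk are the standard sliding-window perplexity recipe for fixed-length causal models. The sketch below is illustrative rather than the repository's exact code: model is assumed to be a Transformers-style causal LM that accepts a labels argument, and encodings a (1, seq_len) token tensor like the one produced by encode(...) above.

import torch

def sliding_window_perplexity(model, encodings, max_length=1024, stride=512, device="cuda"):
    # encodings: LongTensor of shape (1, seq_len) holding the tokenized dataset
    seq_len = encodings.shape[1]
    nlls = []
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_length, seq_len)
        target_len = end - prev_end  # number of new tokens scored in this window
        input_ids = encodings[:, begin:end].to(device)
        target_ids = input_ids.clone()
        target_ids[:, :-target_len] = -100  # ignore the overlapping prefix in the loss
        with torch.no_grad():
            nlls.append(model(input_ids, labels=target_ids).loss)
        prev_end = end
        if end == seq_len:
            break
    # same aggregation as the context line shown in the hunk above
    return torch.exp(torch.stack(nlls).mean())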

File 2 of 2

@@ -126,7 +126,9 @@ def create_train_interface():
         evaluation_log = gr.Markdown(value='')
         evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
-        save_comments = gr.Button('Save comments')
+        with gr.Row():
+            save_comments = gr.Button('Save comments', elem_classes="small-button")
+            refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
     # Training events
     all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after]
@@ -147,6 +149,7 @@ def create_train_interface():
     start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
     stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
+    refresh_table.click(generate_markdown_table, None, evaluation_table, show_progress=True)
     save_comments.click(
         save_past_evaluations, evaluation_table, None).then(
         lambda: "Comments saved.", None, evaluation_log, show_progress=False)