diff --git a/modules/evaluate.py b/modules/evaluate.py index d93e81b4..61e30261 100644 --- a/modules/evaluate.py +++ b/modules/evaluate.py @@ -78,7 +78,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length): yield cumulative_log continue - cumulative_log += f"Processing {model}...\n\n" + cumulative_log += f"Processing {shared.model_name}...\n\n" yield cumulative_log + "Tokenizing the input dataset...\n\n" encodings = encode(text, add_special_tokens=False) seq_len = encodings.shape[1] @@ -110,7 +110,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length): ppl = torch.exp(torch.stack(nlls).mean()) add_entry_to_past_evaluations(float(ppl), shared.model_name, input_dataset, stride, _max_length) save_past_evaluations(past_evaluations) - cumulative_log += f"Done. The perplexity is: {float(ppl)}\n\n" + cumulative_log += f"The perplexity for {shared.model_name} is: {float(ppl)}\n\n" yield cumulative_log diff --git a/modules/training.py b/modules/training.py index 3de85a52..f86fa5a4 100644 --- a/modules/training.py +++ b/modules/training.py @@ -126,7 +126,9 @@ def create_train_interface(): evaluation_log = gr.Markdown(value='') evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True) - save_comments = gr.Button('Save comments') + with gr.Row(): + save_comments = gr.Button('Save comments', elem_classes="small-button") + refresh_table = gr.Button('Refresh the table', elem_classes="small-button") # Training events all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after] @@ -147,6 +149,7 @@ def create_train_interface(): start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False) stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False) + refresh_table.click(generate_markdown_table, None, evaluation_table, show_progress=True) save_comments.click( save_past_evaluations, evaluation_table, None).then( lambda: "Comments saved.", None, evaluation_log, show_progress=False)