From 75adc110d41e6904d853a08aacdf28ea0c3bc97f Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 23 May 2023 01:54:52 -0300 Subject: [PATCH] Fix "perplexity evaluation" progress messages --- modules/evaluate.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/modules/evaluate.py b/modules/evaluate.py index adafa713..d93e81b4 100644 --- a/modules/evaluate.py +++ b/modules/evaluate.py @@ -28,7 +28,9 @@ past_evaluations = load_past_evaluations() def save_past_evaluations(df): global past_evaluations past_evaluations = df - df.to_csv(Path('logs/evaluations.csv'), index=False) + filepath = Path('logs/evaluations.csv') + filepath.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(filepath, index=False) def calculate_perplexity(models, input_dataset, stride, _max_length): @@ -39,7 +41,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length): global past_evaluations cumulative_log = '' - cumulative_log += "Loading the input dataset...\n" + cumulative_log += "Loading the input dataset...\n\n" yield cumulative_log # Copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/utils/datautils.py @@ -58,13 +60,13 @@ def calculate_perplexity(models, input_dataset, stride, _max_length): for model in models: if is_in_past_evaluations(model, input_dataset, stride, _max_length): - cumulative_log += f"{model} has already been tested. Ignoring.\n" + cumulative_log += f"{model} has already been tested. Ignoring.\n\n" yield cumulative_log continue if model != 'current model': try: - yield cumulative_log + f"Loading {model}...\n" + yield cumulative_log + f"Loading {model}...\n\n" model_settings = get_model_specific_settings(model) shared.settings.update(model_settings) # hijacking the interface defaults update_model_parameters(model_settings) # hijacking the command-line arguments @@ -72,12 +74,12 @@ def calculate_perplexity(models, input_dataset, stride, _max_length): unload_model() shared.model, shared.tokenizer = load_model(shared.model_name) except: - cumulative_log += f"Failed to load {model}. Moving on.\n" + cumulative_log += f"Failed to load {model}. Moving on.\n\n" yield cumulative_log continue - cumulative_log += f"Processing {model}...\n" - yield cumulative_log + "Tokenizing the input dataset...\n" + cumulative_log += f"Processing {model}...\n\n" + yield cumulative_log + "Tokenizing the input dataset...\n\n" encodings = encode(text, add_special_tokens=False) seq_len = encodings.shape[1] max_length = _max_length or shared.model.config.max_position_embeddings