From 2d44adbb762a37b909df9eeeea43e9e151c7c7cf Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 16 Oct 2023 10:52:50 -0700
Subject: [PATCH] Clear the torch cache while evaluating

---
 modules/evaluate.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/modules/evaluate.py b/modules/evaluate.py
index a569674e..4b5335ff 100644
--- a/modules/evaluate.py
+++ b/modules/evaluate.py
@@ -7,7 +7,7 @@ from datasets import load_dataset
 from tqdm import tqdm
 
 from modules import shared
-from modules.models import load_model, unload_model
+from modules.models import clear_torch_cache, load_model, unload_model
 from modules.models_settings import get_model_metadata, update_model_parameters
 from modules.text_generation import encode
 
@@ -97,7 +97,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
             input_ids = encodings[:, begin_loc:end_loc]
             target_ids = input_ids.clone()
             target_ids[:, :-trg_len] = -100
-
+            clear_torch_cache()
             with torch.no_grad():
                 outputs = shared.model(input_ids=input_ids, labels=target_ids)
 
@@ -107,7 +107,6 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
                 neg_log_likelihood = outputs.loss
 
             nlls.append(neg_log_likelihood)
-
             prev_end_loc = end_loc
             if end_loc == seq_len:
                 break
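
Note: the patch calls the existing clear_torch_cache() helper from modules/models.py once per evaluation window, so memory held over from the previous forward pass is released before the next one; this trades a small amount of speed for a lower peak VRAM footprint during perplexity evaluation. The helper's body is not part of this diff; a minimal sketch of the usual pattern (collect Python garbage, then empty the CUDA allocator cache when a GPU is in use; the shared.args.cpu guard is an assumption about this codebase) looks like:

    import gc

    import torch

    from modules import shared


    def clear_torch_cache():
        # Drop unreachable Python objects that may still hold tensor references.
        gc.collect()
        # Assumed guard: skip the CUDA call when running in CPU-only mode.
        if not shared.args.cpu:
            # Release unused memory held by PyTorch's CUDA caching allocator.
            torch.cuda.empty_cache()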