mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 09:19:23 +01:00
Clear the torch cache while evaluating
This commit is contained in:
parent
388d1864a6
commit
2d44adbb76
@ -7,7 +7,7 @@ from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import shared
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models import clear_torch_cache, load_model, unload_model
|
||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||
from modules.text_generation import encode
|
||||
|
||||
@ -97,7 +97,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
||||
input_ids = encodings[:, begin_loc:end_loc]
|
||||
target_ids = input_ids.clone()
|
||||
target_ids[:, :-trg_len] = -100
|
||||
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
outputs = shared.model(input_ids=input_ids, labels=target_ids)
|
||||
|
||||
@ -107,7 +107,6 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
||||
neg_log_likelihood = outputs.loss
|
||||
|
||||
nlls.append(neg_log_likelihood)
|
||||
|
||||
prev_end_loc = end_loc
|
||||
if end_loc == seq_len:
|
||||
break
|
||||
|
Loading…
Reference in New Issue
Block a user