mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Clear the torch cache while evaluating
This commit is contained in:
parent
388d1864a6
commit
2d44adbb76
@ -7,7 +7,7 @@ from datasets import load_dataset
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.models import load_model, unload_model
|
from modules.models import clear_torch_cache, load_model, unload_model
|
||||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||||
from modules.text_generation import encode
|
from modules.text_generation import encode
|
||||||
|
|
||||||
@ -97,7 +97,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
|||||||
input_ids = encodings[:, begin_loc:end_loc]
|
input_ids = encodings[:, begin_loc:end_loc]
|
||||||
target_ids = input_ids.clone()
|
target_ids = input_ids.clone()
|
||||||
target_ids[:, :-trg_len] = -100
|
target_ids[:, :-trg_len] = -100
|
||||||
|
clear_torch_cache()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = shared.model(input_ids=input_ids, labels=target_ids)
|
outputs = shared.model(input_ids=input_ids, labels=target_ids)
|
||||||
|
|
||||||
@ -107,7 +107,6 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
|||||||
neg_log_likelihood = outputs.loss
|
neg_log_likelihood = outputs.loss
|
||||||
|
|
||||||
nlls.append(neg_log_likelihood)
|
nlls.append(neg_log_likelihood)
|
||||||
|
|
||||||
prev_end_loc = end_loc
|
prev_end_loc = end_loc
|
||||||
if end_loc == seq_len:
|
if end_loc == seq_len:
|
||||||
break
|
break
|
||||||
|
Loading…
Reference in New Issue
Block a user