Fix the early stopping callback #559

This commit is contained in:
oobabooga 2023-03-25 12:35:52 -03:00 committed by GitHub
parent a1f12d607f
commit 8c8e8b4450
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,7 +25,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
continue
for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
if torch.all(torch.eq(self.sentinel_token_ids[i], window)):
if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
return True
return False