import transformers

import modules.shared as shared


class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    """Stopping criteria that fires when generated text ends with a sentinel.

    For every sample in the batch, the tokens from ``starting_idx`` onward
    are inspected, and generation is stopped as soon as that tail ends with
    any one of the configured sentinel token sequences.

    Only the trailing window of each sample is compared (not every earlier
    position), so a sentinel is detected on the call in which its final
    token is present at the end of the sample.
    NOTE(review): this presumes the criteria is evaluated after each newly
    generated token -- confirm against the generation loop.
    """

    def __init__(self, sentinel_token_ids: list, starting_idx: int):
        """Store the sentinels and pre-compute the shortest sentinel length.

        Args:
            sentinel_token_ids: Sentinel token sequences. Each element must
                expose ``.shape`` and be comparable with ``torch.eq``
                (presumably tensors of shape ``(1, n)`` -- confirm against
                the caller). May be empty, in which case the criteria never
                triggers.
            starting_idx: Index into each sample from which matching starts;
                tokens before it are ignored.
        """
        super().__init__()
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx
        # ``default=0`` keeps an empty sentinel list valid: a bare ``min``
        # over an empty iterable raises ValueError, whereas an empty list
        # should simply mean "never stop".
        self.shortest = min(
            (x.shape[-1] for x in sentinel_token_ids),
            default=0,
        )

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        """Return True if any sample's tail ends with any sentinel.

        Args:
            input_ids: Batch of token id sequences; iterated sample by sample.
            _scores: Unused; accepted to match the stopping-criteria call
                signature.

        Returns:
            True as soon as one sample matches one sentinel, otherwise False.
        """
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            trimmed_len = trimmed_sample.shape[-1]
            # Cheap early exit: too short to contain even the shortest
            # sentinel, so no per-sentinel comparison is needed.
            if trimmed_len < self.shortest:
                continue

            for sentinel in self.sentinel_token_ids:
                sentinel_len = sentinel.shape[-1]
                if trimmed_len < sentinel_len:
                    continue

                # Compare only the last ``sentinel_len`` tokens with the
                # sentinel; ``torch.eq`` broadcasts across any leading
                # dimension the sentinel may carry.
                window = trimmed_sample[-sentinel_len:]
                if torch.all(torch.eq(sentinel, window)):
                    return True

        return False