diff --git a/modules/stopping_criteria.py b/modules/stopping_criteria.py new file mode 100644 index 00000000..3baadf6c --- /dev/null +++ b/modules/stopping_criteria.py @@ -0,0 +1,31 @@ +''' +This code was copied from + +https://github.com/PygmalionAI/gradio-ui/ + +''' + +import torch +import transformers + +class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): + + def __init__(self, sentinel_token_ids: torch.LongTensor, + starting_idx: int): + transformers.StoppingCriteria.__init__(self) + self.sentinel_token_ids = sentinel_token_ids + self.starting_idx = starting_idx + + def __call__(self, input_ids: torch.LongTensor, + _scores: torch.FloatTensor) -> bool: + for sample in input_ids: + trimmed_sample = sample[self.starting_idx:] + # Can't unfold, output is still too tiny. Skip. + if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]: + continue + + for window in trimmed_sample.unfold( + 0, self.sentinel_token_ids.shape[-1], 1): + if torch.all(torch.eq(self.sentinel_token_ids, window)): + return True + return False