optimize stopping strings processing (#1625)

Alex "mcmonkey" Goodwin 2023-05-01 21:21:54 -07:00 committed by GitHub
parent e6a78c00f2
commit 0df0b2d0f9

@@ -9,25 +9,30 @@ import transformers
 import modules.shared as shared
 
 
 # Copied from https://github.com/PygmalionAI/gradio-ui/
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
 
     def __init__(self, sentinel_token_ids: list, starting_idx: int):
         transformers.StoppingCriteria.__init__(self)
         self.sentinel_token_ids = sentinel_token_ids
         self.starting_idx = starting_idx
+        self.shortest = min([x.shape[-1] for x in sentinel_token_ids])
 
     def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
         for sample in input_ids:
             trimmed_sample = sample[self.starting_idx:]
+            trimmed_len = trimmed_sample.shape[-1]
+            if trimmed_len < self.shortest:
+                continue
 
-            for i in range(len(self.sentinel_token_ids)):
-                # Can't unfold, output is still too tiny. Skip.
-                if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
+            for sentinel in self.sentinel_token_ids:
+                sentinel_len = sentinel.shape[-1]
+                if trimmed_len < sentinel_len:
                     continue
 
-                for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
-                    if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
-                        return True
+                window = trimmed_sample[-sentinel_len:]
+                if torch.all(torch.eq(sentinel, window)):
+                    return True
+
         return False
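
The optimization: the old __call__ slid a window over every position of the generated suffix with unfold() on every generation step, while the new version precomputes the shortest sentinel length once in __init__ (for an early bail-out) and compares each sentinel only against the trailing sentinel_len tokens. Checking only the tail is sufficient because stopping criteria run after each newly generated token, so a stopping string can only ever complete at the very end of the sequence; any earlier occurrence would already have triggered a stop on a previous step. The per-step cost therefore no longer grows with the length of the generated text.

A minimal usage sketch, not part of the commit, showing how the class is typically wired into transformers' generate(). The model name, prompt, stopping strings, and the import path (assuming the file shown is modules/callbacks.py) are illustrative assumptions, not taken from this change.

import transformers

from modules.callbacks import _SentinelTokenStoppingCriteria  # assumed path of the file in the diff

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # any causal LM works
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "You: Hello!\nBot:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Encode each stopping string into a (1, n) tensor of token ids.
stopping_strings = ["\nYou:", "\nBot:"]
sentinel_token_ids = [
    tokenizer(s, add_special_tokens=False, return_tensors="pt").input_ids
    for s in stopping_strings
]

# starting_idx marks where generated tokens begin, so the prompt itself
# can never trigger a stop.
criteria = transformers.StoppingCriteriaList([
    _SentinelTokenStoppingCriteria(
        sentinel_token_ids=sentinel_token_ids,
        starting_idx=input_ids.shape[-1],
    )
])

output = model.generate(input_ids, max_new_tokens=64, stopping_criteria=criteria)
print(tokenizer.decode(output[0][input_ids.shape[-1]:]))

One caveat carried over from the original design: the comparison happens on token ids, so a stopping string only matches if the model emits exactly the same tokenization that the tokenizer produced for it here.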