mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
optimize stopping strings processing (#1625)
This commit is contained in:
parent
e6a78c00f2
commit
0df0b2d0f9
@ -9,25 +9,30 @@ import transformers
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
# Copied from https://github.com/PygmalionAI/gradio-ui/
|
||||
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    """Stop generation once any sentinel token sequence ends the output.

    Only the tokens generated after ``starting_idx`` are inspected, and only
    the trailing window of each sample is compared against every sentinel.
    This relies on the criteria being evaluated after each generation step,
    so a sentinel can only newly appear at the very end of the sequence.
    """

    def __init__(self, sentinel_token_ids: list, starting_idx: int):
        """Store the sentinels and precompute the shortest sentinel length.

        Args:
            sentinel_token_ids: Tensors of token ids, one per stopping
                string. NOTE(review): presumably each is shaped ``(1, n)``;
                only ``shape[-1]`` is read here -- confirm against caller.
            starting_idx: Index of the first generated token; everything
                before it (the prompt) is ignored when matching.
        """
        super().__init__()
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx
        # Length of the shortest sentinel, used as a cheap early-out while
        # the generated text is still shorter than every sentinel.
        # ``default=0`` keeps construction from raising on an empty list.
        self.shortest = min(
            (x.shape[-1] for x in sentinel_token_ids),
            default=0,
        )

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        """Return True if any sample's generated tail equals a sentinel.

        Args:
            input_ids: Batch of token id sequences; iterated sample by sample.
            _scores: Unused; accepted to match the stopping-criteria call
                signature.

        Returns:
            True as soon as one sample ends with one of the sentinels,
            otherwise False.
        """
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            trimmed_len = trimmed_sample.shape[-1]
            # Nothing can match until at least the shortest sentinel fits.
            if trimmed_len < self.shortest:
                continue

            for sentinel in self.sentinel_token_ids:
                sentinel_len = sentinel.shape[-1]
                # Output is still too short for this particular sentinel.
                if trimmed_len < sentinel_len:
                    continue

                # Compare only the trailing window instead of sliding over
                # the whole generated sequence on every call.
                window = trimmed_sample[-sentinel_len:]
                if torch.all(torch.eq(sentinel, window)):
                    return True

        return False
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user