Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-11-22 16:17:57 +01:00.
Commit 0df0b2d0f9 — "optimize stopping strings processing" (#1625); parent commit e6a78c00f2.
@ -9,25 +9,30 @@ import transformers
|
|||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
|
|
||||||
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    """Stop generation once any sentinel token sequence appears at the end
    of the newly generated portion of the output.

    Each entry of ``sentinel_token_ids`` is a tensor of token ids. Only
    tokens at index ``starting_idx`` and later (i.e. the generated tokens,
    not the prompt) are inspected.
    """

    def __init__(self, sentinel_token_ids: list, starting_idx: int):
        """
        Args:
            sentinel_token_ids: list of token-id tensors; generation stops
                when any of them matches the tail of the generated output.
            starting_idx: index of the first generated token (length of the
                prompt), used to trim the prompt before matching.
        """
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx
        # Precompute the shortest sentinel length once: no sentinel can
        # possibly match until at least this many tokens were generated.
        self.shortest = min(x.shape[-1] for x in sentinel_token_ids)

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        """Return True when any sample in the batch ends with a sentinel."""
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            trimmed_len = trimmed_sample.shape[-1]
            # Generated output is still shorter than every sentinel — skip.
            if trimmed_len < self.shortest:
                continue

            for sentinel in self.sentinel_token_ids:
                sentinel_len = sentinel.shape[-1]
                if trimmed_len < sentinel_len:
                    continue

                # This criterion runs after every generated token, so only
                # the tail of the output can contain a *new* match; checking
                # the last sentinel_len tokens is sufficient (this is the
                # optimization over unfolding every window).
                window = trimmed_sample[-sentinel_len:]
                if torch.all(torch.eq(sentinel, window)):
                    return True

        return False
|
Loading…
Reference in New Issue
Block a user