mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
DRY sampler improvements (#6053)
This commit is contained in:
parent
b675151f25
commit
3abafee696
@ -204,21 +204,25 @@ class DRYLogitsProcessor(LogitsProcessor):
|
||||
input_ids = input_ids[:, -self._range:]
|
||||
|
||||
for input_ids_row, scores_row in zip(input_ids, scores):
|
||||
# Raw integer must be extracted here to check for set membership.
|
||||
last_token = input_ids_row[-1].item()
|
||||
# Use normal Python data types for improved performance
|
||||
input_ids = input_ids_row.tolist()
|
||||
|
||||
last_token = input_ids[-1]
|
||||
if last_token in self.sequence_breakers:
|
||||
continue
|
||||
|
||||
# Exclude the last token as it always matches.
|
||||
match_indices = (input_ids_row[:-1] == last_token).nonzero()
|
||||
match_indices = []
|
||||
for idx, val in enumerate(input_ids[:-1]):
|
||||
if val == last_token:
|
||||
match_indices.append(idx)
|
||||
|
||||
# Stores the maximum matching sequence length
|
||||
# for each token immediately following the sequence in the input.
|
||||
match_lengths = {}
|
||||
|
||||
for i in match_indices:
|
||||
next_token = input_ids_row[i + 1].item()
|
||||
next_token = input_ids[i + 1]
|
||||
|
||||
if next_token in self.sequence_breakers:
|
||||
continue
|
||||
@ -227,15 +231,15 @@ class DRYLogitsProcessor(LogitsProcessor):
|
||||
# so the match is at least of length 1.
|
||||
match_length = 1
|
||||
|
||||
# Extend the match backwards as far as possible.
|
||||
while True:
|
||||
# Extend the match backwards, up to a length of at most 50. The cap prevents exponent overflow in the penalty calculation and also improves worst-case performance.
|
||||
while match_length < 50:
|
||||
j = i - match_length
|
||||
if j < 0:
|
||||
# Start of input reached.
|
||||
break
|
||||
|
||||
previous_token = input_ids_row[-(match_length + 1)].item()
|
||||
if input_ids_row[j] != previous_token:
|
||||
previous_token = input_ids[-(match_length + 1)]
|
||||
if input_ids[j] != previous_token:
|
||||
# Start of match reached.
|
||||
break
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user