DRY sampler improvements (#6053)

This commit is contained in:
Belladore 2024-06-13 05:39:11 +03:00 committed by GitHub
parent b675151f25
commit 3abafee696
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -204,21 +204,25 @@ class DRYLogitsProcessor(LogitsProcessor):
input_ids = input_ids[:, -self._range:] input_ids = input_ids[:, -self._range:]
for input_ids_row, scores_row in zip(input_ids, scores): for input_ids_row, scores_row in zip(input_ids, scores):
# Raw integer must be extracted here to check for set membership. # Use normal Python data types for improved performance
last_token = input_ids_row[-1].item() input_ids = input_ids_row.tolist()
last_token = input_ids[-1]
if last_token in self.sequence_breakers: if last_token in self.sequence_breakers:
continue continue
# Exclude the last token as it always matches. # Exclude the last token as it always matches.
match_indices = (input_ids_row[:-1] == last_token).nonzero() match_indices = []
for idx, val in enumerate(input_ids[:-1]):
if val == last_token:
match_indices.append(idx)
# Stores the maximum matching sequence length # Stores the maximum matching sequence length
# for each token immediately following the sequence in the input. # for each token immediately following the sequence in the input.
match_lengths = {} match_lengths = {}
for i in match_indices: for i in match_indices:
next_token = input_ids_row[i + 1].item() next_token = input_ids[i + 1]
if next_token in self.sequence_breakers: if next_token in self.sequence_breakers:
continue continue
@ -227,15 +231,15 @@ class DRYLogitsProcessor(LogitsProcessor):
# so the match is at least of length 1. # so the match is at least of length 1.
match_length = 1 match_length = 1
# Extend the match backwards as far as possible. # Extend the match backwards (at most to 50 to prevent exponent overflow at penalty calculation) (this cap also improves performance on worst case)
while True: while match_length < 50:
j = i - match_length j = i - match_length
if j < 0: if j < 0:
# Start of input reached. # Start of input reached.
break break
previous_token = input_ids_row[-(match_length + 1)].item() previous_token = input_ids[-(match_length + 1)]
if input_ids_row[j] != previous_token: if input_ids[j] != previous_token:
# Start of match reached. # Start of match reached.
break break