mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
DRY sampler improvements (#6053)
This commit is contained in:
parent
b675151f25
commit
3abafee696
@ -204,21 +204,25 @@ class DRYLogitsProcessor(LogitsProcessor):
|
|||||||
input_ids = input_ids[:, -self._range:]
|
input_ids = input_ids[:, -self._range:]
|
||||||
|
|
||||||
for input_ids_row, scores_row in zip(input_ids, scores):
|
for input_ids_row, scores_row in zip(input_ids, scores):
|
||||||
# Raw integer must be extracted here to check for set membership.
|
# Use normal Python data types for improved performance
|
||||||
last_token = input_ids_row[-1].item()
|
input_ids = input_ids_row.tolist()
|
||||||
|
|
||||||
|
last_token = input_ids[-1]
|
||||||
if last_token in self.sequence_breakers:
|
if last_token in self.sequence_breakers:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Exclude the last token as it always matches.
|
# Exclude the last token as it always matches.
|
||||||
match_indices = (input_ids_row[:-1] == last_token).nonzero()
|
match_indices = []
|
||||||
|
for idx, val in enumerate(input_ids[:-1]):
|
||||||
|
if val == last_token:
|
||||||
|
match_indices.append(idx)
|
||||||
|
|
||||||
# Stores the maximum matching sequence length
|
# Stores the maximum matching sequence length
|
||||||
# for each token immediately following the sequence in the input.
|
# for each token immediately following the sequence in the input.
|
||||||
match_lengths = {}
|
match_lengths = {}
|
||||||
|
|
||||||
for i in match_indices:
|
for i in match_indices:
|
||||||
next_token = input_ids_row[i + 1].item()
|
next_token = input_ids[i + 1]
|
||||||
|
|
||||||
if next_token in self.sequence_breakers:
|
if next_token in self.sequence_breakers:
|
||||||
continue
|
continue
|
||||||
@ -227,15 +231,15 @@ class DRYLogitsProcessor(LogitsProcessor):
|
|||||||
# so the match is at least of length 1.
|
# so the match is at least of length 1.
|
||||||
match_length = 1
|
match_length = 1
|
||||||
|
|
||||||
# Extend the match backwards as far as possible.
|
# Extend the match backwards (at most to 50 to prevent exponent overflow at penalty calculation) (this cap also improves performance on worst case)
|
||||||
while True:
|
while match_length < 50:
|
||||||
j = i - match_length
|
j = i - match_length
|
||||||
if j < 0:
|
if j < 0:
|
||||||
# Start of input reached.
|
# Start of input reached.
|
||||||
break
|
break
|
||||||
|
|
||||||
previous_token = input_ids_row[-(match_length + 1)].item()
|
previous_token = input_ids[-(match_length + 1)]
|
||||||
if input_ids_row[j] != previous_token:
|
if input_ids[j] != previous_token:
|
||||||
# Start of match reached.
|
# Start of match reached.
|
||||||
break
|
break
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user