DRY: A modern repetition penalty that reliably prevents looping (#5677)

This commit is contained in:
Philipp Emanuel Weidmann 2024-05-20 08:23:47 +05:30 committed by GitHub
parent 9f77ed1b98
commit 852c943769
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 137 additions and 8 deletions

View File

@ -33,6 +33,10 @@ class GenerationOptions(BaseModel):
seed: int = -1
encoder_repetition_penalty: float = 1
no_repeat_ngram_size: int = 0
dry_multiplier: float = 0
dry_base: float = 1.75
dry_allowed_length: int = 2
dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
truncation_length: int = 0
max_tokens_second: int = 0
prompt_lookup_num_tokens: int = 0

View File

@ -178,6 +178,10 @@ def transformers_samplers():
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'seed',
'do_sample',
'penalty_alpha',
@ -251,6 +255,10 @@ loaders_samplers = {
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'seed',
'do_sample',
'mirostat_mode',
@ -309,6 +317,10 @@ loaders_samplers = {
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'seed',
'do_sample',
'mirostat_mode',

View File

@ -40,6 +40,10 @@ def default_preset():
'do_sample': True,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'dry_multiplier': 0,
'dry_base': 1.75,
'dry_allowed_length': 2,
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat'
}

View File

@ -1,3 +1,4 @@
import json
import math
import pprint
@ -220,6 +221,74 @@ class TopALogitsWarper(LogitsWarper):
return scores
class DRYLogitsProcessor(LogitsProcessor):
def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int):
self.multiplier = multiplier
self.base = base
self.allowed_length = allowed_length
self.sequence_breakers = sequence_breakers
self._range = _range
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self._range > 0:
input_ids = input_ids[:, -self._range:]
for input_ids_row, scores_row in zip(input_ids, scores):
# Raw integer must be extracted here to check for set membership.
last_token = input_ids_row[-1].item()
if last_token in self.sequence_breakers:
continue
# Exclude the last token as it always matches.
match_indices = (input_ids_row[:-1] == last_token).nonzero()
# Stores the maximum matching sequence length
# for each token immediately following the sequence in the input.
match_lengths = {}
for i in match_indices:
next_token = input_ids_row[i+1].item()
if next_token in self.sequence_breakers:
continue
# We have already found that `last_token` matches at this index,
# so the match is at least of length 1.
match_length = 1
# Extend the match backwards as far as possible.
while True:
j = i - match_length
if j < 0:
# Start of input reached.
break
previous_token = input_ids_row[-(match_length+1)].item()
if input_ids_row[j] != previous_token:
# Start of match reached.
break
if previous_token in self.sequence_breakers:
# Sequence-breaking token reached.
break
match_length += 1
if next_token in match_lengths:
match_lengths[next_token] = max(match_length, match_lengths[next_token])
else:
match_lengths[next_token] = match_length
# Apply penalties.
for token, match_length in match_lengths.items():
if match_length >= self.allowed_length:
penalty = self.multiplier * self.base ** (match_length - self.allowed_length)
scores_row[token] -= penalty
return scores
class MirostatLogitsWarper(LogitsWarper):
def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if mirostat_mode not in [2]:
@ -448,20 +517,44 @@ def get_logits_warper_patch(self, generation_config):
def get_logits_processor_patch(self, **kwargs):
repetition_penalty = kwargs['generation_config'].repetition_penalty
presence_penalty = kwargs['generation_config'].presence_penalty
frequency_penalty = kwargs['generation_config'].frequency_penalty
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
generation_config = kwargs['generation_config']
do_rep_pen_hijack = (generation_config.repetition_penalty > 1) or (generation_config.presence_penalty != 0) or (generation_config.frequency_penalty != 0)
if do_rep_pen_hijack:
kwargs['generation_config'].repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
generation_config.repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
result = self._get_logits_processor_old(**kwargs)
if do_rep_pen_hijack:
for i in range(len(result)):
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, presence_penalty, frequency_penalty, repetition_penalty_range)
result[i] = RepetitionPenaltyLogitsProcessorWithRange(
generation_config.repetition_penalty,
generation_config.presence_penalty,
generation_config.frequency_penalty,
generation_config.repetition_penalty_range
)
if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
dry_sequence_breakers = generation_config.dry_sequence_breakers
# Support both JSON array notation and comma-separated strings.
if not dry_sequence_breakers.startswith("["):
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
sequence_breaker_strings = json.loads(dry_sequence_breakers)
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
result.append(
DRYLogitsProcessor(
multiplier=generation_config.dry_multiplier,
base=generation_config.dry_base,
allowed_length=generation_config.dry_allowed_length,
sequence_breakers=sequence_breakers,
_range=generation_config.repetition_penalty_range,
)
)
return result
@ -483,6 +576,10 @@ def generation_config_init_patch(self, **kwargs):
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
self.presence_penalty = kwargs.pop("presence_penalty", 0)
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
self.dry_multiplier = kwargs.pop("dry_multiplier", 0.0)
self.dry_base = kwargs.pop("dry_base", 1.75)
self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2)
self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"')
self.temperature_last = kwargs.pop("temperature_last", False)
self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat'])

View File

@ -272,7 +272,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size']:
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers']:
if k in state:
generate_params[k] = state[k]

View File

@ -147,6 +147,10 @@ def list_interface_input_elements():
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'do_sample',
'penalty_alpha',
'mirostat_mode',

View File

@ -38,6 +38,14 @@ def create_ui(default_preset):
shared.gradio['presence_penalty'] = gr.Slider(0, 2, value=generate_params['presence_penalty'], step=0.05, label='presence_penalty')
shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
with gr.Blocks():
gr.Markdown("[DRY sequence repetition penalty](https://github.com/oobabooga/text-generation-webui/pull/5677)")
shared.gradio['dry_multiplier'] = gr.Slider(0, 5, value=generate_params['dry_multiplier'], step=0.01, label='dry_multiplier', info='Set to value > 0 to enable DRY. Controls the magnitude of the penalty for the shortest penalized sequences.')
shared.gradio['dry_base'] = gr.Slider(1, 4, value=generate_params['dry_base'], step=0.01, label='dry_base', info='Controls how fast the penalty grows with increasing sequence length.')
shared.gradio['dry_allowed_length'] = gr.Slider(1, 20, value=generate_params['dry_allowed_length'], step=1, label='dry_allowed_length', info='Longest sequence that can be repeated without being penalized.')
shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')
gr.Markdown("[Learn more](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab)")
with gr.Column():