From 852c9437690893bd953afacbb092eb929d139f32 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Mon, 20 May 2024 08:23:47 +0530 Subject: [PATCH] DRY: A modern repetition penalty that reliably prevents looping (#5677) --- extensions/openai/typing.py | 4 ++ modules/loaders.py | 12 ++++ modules/presets.py | 4 ++ modules/sampler_hijack.py | 111 +++++++++++++++++++++++++++++++++--- modules/text_generation.py | 2 +- modules/ui.py | 4 ++ modules/ui_parameters.py | 8 +++ 7 files changed, 137 insertions(+), 8 deletions(-) diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index 2b30ebf2..4015f6a1 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -33,6 +33,10 @@ class GenerationOptions(BaseModel): seed: int = -1 encoder_repetition_penalty: float = 1 no_repeat_ngram_size: int = 0 + dry_multiplier: float = 0 + dry_base: float = 1.75 + dry_allowed_length: int = 2 + dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"' truncation_length: int = 0 max_tokens_second: int = 0 prompt_lookup_num_tokens: int = 0 diff --git a/modules/loaders.py b/modules/loaders.py index e7ff2d21..fa7595bf 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -178,6 +178,10 @@ def transformers_samplers(): 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', + 'dry_multiplier', + 'dry_base', + 'dry_allowed_length', + 'dry_sequence_breakers', 'seed', 'do_sample', 'penalty_alpha', @@ -251,6 +255,10 @@ loaders_samplers = { 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', + 'dry_multiplier', + 'dry_base', + 'dry_allowed_length', + 'dry_sequence_breakers', 'seed', 'do_sample', 'mirostat_mode', @@ -309,6 +317,10 @@ loaders_samplers = { 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', + 'dry_multiplier', + 'dry_base', + 'dry_allowed_length', + 'dry_sequence_breakers', 'seed', 'do_sample', 'mirostat_mode', diff --git a/modules/presets.py b/modules/presets.py index a7cda31c..b00e829e 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -40,6 +40,10 @@ def default_preset(): 'do_sample': True, 'encoder_repetition_penalty': 1, 'no_repeat_ngram_size': 0, + 'dry_multiplier': 0, + 'dry_base': 1.75, + 'dry_allowed_length': 2, + 'dry_sequence_breakers': '"\\n", ":", "\\"", "*"', 'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat' } diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index da52f4d0..1f9ca52c 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -1,3 +1,4 @@ +import json import math import pprint @@ -220,6 +221,74 @@ class TopALogitsWarper(LogitsWarper): return scores +class DRYLogitsProcessor(LogitsProcessor): + def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int): + self.multiplier = multiplier + self.base = base + self.allowed_length = allowed_length + self.sequence_breakers = sequence_breakers + self._range = _range + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if self._range > 0: + input_ids = input_ids[:, -self._range:] + + for input_ids_row, scores_row in zip(input_ids, scores): + # Raw integer must be extracted here to check for set membership. + last_token = input_ids_row[-1].item() + + if last_token in self.sequence_breakers: + continue + + # Exclude the last token as it always matches. + match_indices = (input_ids_row[:-1] == last_token).nonzero() + + # Stores the maximum matching sequence length + # for each token immediately following the sequence in the input. + match_lengths = {} + + for i in match_indices: + next_token = input_ids_row[i+1].item() + + if next_token in self.sequence_breakers: + continue + + # We have already found that `last_token` matches at this index, + # so the match is at least of length 1. + match_length = 1 + + # Extend the match backwards as far as possible. + while True: + j = i - match_length + if j < 0: + # Start of input reached. + break + + previous_token = input_ids_row[-(match_length+1)].item() + if input_ids_row[j] != previous_token: + # Start of match reached. + break + + if previous_token in self.sequence_breakers: + # Sequence-breaking token reached. + break + + match_length += 1 + + if next_token in match_lengths: + match_lengths[next_token] = max(match_length, match_lengths[next_token]) + else: + match_lengths[next_token] = match_length + + # Apply penalties. + for token, match_length in match_lengths.items(): + if match_length >= self.allowed_length: + penalty = self.multiplier * self.base ** (match_length - self.allowed_length) + scores_row[token] -= penalty + + return scores + + class MirostatLogitsWarper(LogitsWarper): def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): if mirostat_mode not in [2]: @@ -448,20 +517,44 @@ def get_logits_warper_patch(self, generation_config): def get_logits_processor_patch(self, **kwargs): - repetition_penalty = kwargs['generation_config'].repetition_penalty - presence_penalty = kwargs['generation_config'].presence_penalty - frequency_penalty = kwargs['generation_config'].frequency_penalty - repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range - do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0) + generation_config = kwargs['generation_config'] + + do_rep_pen_hijack = (generation_config.repetition_penalty > 1) or (generation_config.presence_penalty != 0) or (generation_config.frequency_penalty != 0) if do_rep_pen_hijack: - kwargs['generation_config'].repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created + generation_config.repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created result = self._get_logits_processor_old(**kwargs) if do_rep_pen_hijack: for i in range(len(result)): if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor': - result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, presence_penalty, frequency_penalty, repetition_penalty_range) + result[i] = RepetitionPenaltyLogitsProcessorWithRange( + generation_config.repetition_penalty, + generation_config.presence_penalty, + generation_config.frequency_penalty, + generation_config.repetition_penalty_range + ) + + if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0: + dry_sequence_breakers = generation_config.dry_sequence_breakers + + # Support both JSON array notation and comma-separated strings. + if not dry_sequence_breakers.startswith("["): + dry_sequence_breakers = "[" + dry_sequence_breakers + "]" + + sequence_breaker_strings = json.loads(dry_sequence_breakers) + # Prefix with 'a' to get the correct encoding of the token at the end of a text. + sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings} + + result.append( + DRYLogitsProcessor( + multiplier=generation_config.dry_multiplier, + base=generation_config.dry_base, + allowed_length=generation_config.dry_allowed_length, + sequence_breakers=sequence_breakers, + _range=generation_config.repetition_penalty_range, + ) + ) return result @@ -483,6 +576,10 @@ def generation_config_init_patch(self, **kwargs): self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0) self.presence_penalty = kwargs.pop("presence_penalty", 0) self.frequency_penalty = kwargs.pop("frequency_penalty", 0) + self.dry_multiplier = kwargs.pop("dry_multiplier", 0.0) + self.dry_base = kwargs.pop("dry_base", 1.75) + self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2) + self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"') self.temperature_last = kwargs.pop("temperature_last", False) self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat']) diff --git a/modules/text_generation.py b/modules/text_generation.py index afcdaddb..75294013 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -272,7 +272,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0): def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False): generate_params = {} - for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size']: + for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers']: if k in state: generate_params[k] = state[k] diff --git a/modules/ui.py b/modules/ui.py index 37526168..992616de 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -147,6 +147,10 @@ def list_interface_input_elements(): 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', + 'dry_multiplier', + 'dry_base', + 'dry_allowed_length', + 'dry_sequence_breakers', 'do_sample', 'penalty_alpha', 'mirostat_mode', diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 0414cdde..d62b74c1 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -38,6 +38,14 @@ def create_ui(default_preset): shared.gradio['presence_penalty'] = gr.Slider(0, 2, value=generate_params['presence_penalty'], step=0.05, label='presence_penalty') shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range') shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') + + with gr.Blocks(): + gr.Markdown("[DRY sequence repetition penalty](https://github.com/oobabooga/text-generation-webui/pull/5677)") + shared.gradio['dry_multiplier'] = gr.Slider(0, 5, value=generate_params['dry_multiplier'], step=0.01, label='dry_multiplier', info='Set to value > 0 to enable DRY. Controls the magnitude of the penalty for the shortest penalized sequences.') + shared.gradio['dry_base'] = gr.Slider(1, 4, value=generate_params['dry_base'], step=0.01, label='dry_base', info='Controls how fast the penalty grows with increasing sequence length.') + shared.gradio['dry_allowed_length'] = gr.Slider(1, 20, value=generate_params['dry_allowed_length'], step=1, label='dry_allowed_length', info='Longest sequence that can be repeated without being penalized.') + shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.') + gr.Markdown("[Learn more](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab)") with gr.Column():