mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 15:48:04 +01:00
DRY: A modern repetition penalty that reliably prevents looping (#5677)
This commit is contained in:
parent
9f77ed1b98
commit
852c943769
@ -33,6 +33,10 @@ class GenerationOptions(BaseModel):
|
||||
seed: int = -1
|
||||
encoder_repetition_penalty: float = 1
|
||||
no_repeat_ngram_size: int = 0
|
||||
dry_multiplier: float = 0
|
||||
dry_base: float = 1.75
|
||||
dry_allowed_length: int = 2
|
||||
dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
|
||||
truncation_length: int = 0
|
||||
max_tokens_second: int = 0
|
||||
prompt_lookup_num_tokens: int = 0
|
||||
|
@ -178,6 +178,10 @@ def transformers_samplers():
|
||||
'repetition_penalty_range',
|
||||
'encoder_repetition_penalty',
|
||||
'no_repeat_ngram_size',
|
||||
'dry_multiplier',
|
||||
'dry_base',
|
||||
'dry_allowed_length',
|
||||
'dry_sequence_breakers',
|
||||
'seed',
|
||||
'do_sample',
|
||||
'penalty_alpha',
|
||||
@ -251,6 +255,10 @@ loaders_samplers = {
|
||||
'repetition_penalty_range',
|
||||
'encoder_repetition_penalty',
|
||||
'no_repeat_ngram_size',
|
||||
'dry_multiplier',
|
||||
'dry_base',
|
||||
'dry_allowed_length',
|
||||
'dry_sequence_breakers',
|
||||
'seed',
|
||||
'do_sample',
|
||||
'mirostat_mode',
|
||||
@ -309,6 +317,10 @@ loaders_samplers = {
|
||||
'repetition_penalty_range',
|
||||
'encoder_repetition_penalty',
|
||||
'no_repeat_ngram_size',
|
||||
'dry_multiplier',
|
||||
'dry_base',
|
||||
'dry_allowed_length',
|
||||
'dry_sequence_breakers',
|
||||
'seed',
|
||||
'do_sample',
|
||||
'mirostat_mode',
|
||||
|
@ -40,6 +40,10 @@ def default_preset():
|
||||
'do_sample': True,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'dry_multiplier': 0,
|
||||
'dry_base': 1.75,
|
||||
'dry_allowed_length': 2,
|
||||
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
|
||||
'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat'
|
||||
}
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import math
|
||||
import pprint
|
||||
|
||||
@ -220,6 +221,74 @@ class TopALogitsWarper(LogitsWarper):
|
||||
return scores
|
||||
|
||||
|
||||
class DRYLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int):
|
||||
self.multiplier = multiplier
|
||||
self.base = base
|
||||
self.allowed_length = allowed_length
|
||||
self.sequence_breakers = sequence_breakers
|
||||
self._range = _range
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self._range > 0:
|
||||
input_ids = input_ids[:, -self._range:]
|
||||
|
||||
for input_ids_row, scores_row in zip(input_ids, scores):
|
||||
# Raw integer must be extracted here to check for set membership.
|
||||
last_token = input_ids_row[-1].item()
|
||||
|
||||
if last_token in self.sequence_breakers:
|
||||
continue
|
||||
|
||||
# Exclude the last token as it always matches.
|
||||
match_indices = (input_ids_row[:-1] == last_token).nonzero()
|
||||
|
||||
# Stores the maximum matching sequence length
|
||||
# for each token immediately following the sequence in the input.
|
||||
match_lengths = {}
|
||||
|
||||
for i in match_indices:
|
||||
next_token = input_ids_row[i+1].item()
|
||||
|
||||
if next_token in self.sequence_breakers:
|
||||
continue
|
||||
|
||||
# We have already found that `last_token` matches at this index,
|
||||
# so the match is at least of length 1.
|
||||
match_length = 1
|
||||
|
||||
# Extend the match backwards as far as possible.
|
||||
while True:
|
||||
j = i - match_length
|
||||
if j < 0:
|
||||
# Start of input reached.
|
||||
break
|
||||
|
||||
previous_token = input_ids_row[-(match_length+1)].item()
|
||||
if input_ids_row[j] != previous_token:
|
||||
# Start of match reached.
|
||||
break
|
||||
|
||||
if previous_token in self.sequence_breakers:
|
||||
# Sequence-breaking token reached.
|
||||
break
|
||||
|
||||
match_length += 1
|
||||
|
||||
if next_token in match_lengths:
|
||||
match_lengths[next_token] = max(match_length, match_lengths[next_token])
|
||||
else:
|
||||
match_lengths[next_token] = match_length
|
||||
|
||||
# Apply penalties.
|
||||
for token, match_length in match_lengths.items():
|
||||
if match_length >= self.allowed_length:
|
||||
penalty = self.multiplier * self.base ** (match_length - self.allowed_length)
|
||||
scores_row[token] -= penalty
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class MirostatLogitsWarper(LogitsWarper):
|
||||
def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if mirostat_mode not in [2]:
|
||||
@ -448,20 +517,44 @@ def get_logits_warper_patch(self, generation_config):
|
||||
|
||||
|
||||
def get_logits_processor_patch(self, **kwargs):
|
||||
repetition_penalty = kwargs['generation_config'].repetition_penalty
|
||||
presence_penalty = kwargs['generation_config'].presence_penalty
|
||||
frequency_penalty = kwargs['generation_config'].frequency_penalty
|
||||
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
|
||||
do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
|
||||
generation_config = kwargs['generation_config']
|
||||
|
||||
do_rep_pen_hijack = (generation_config.repetition_penalty > 1) or (generation_config.presence_penalty != 0) or (generation_config.frequency_penalty != 0)
|
||||
if do_rep_pen_hijack:
|
||||
kwargs['generation_config'].repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
|
||||
generation_config.repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
|
||||
|
||||
result = self._get_logits_processor_old(**kwargs)
|
||||
|
||||
if do_rep_pen_hijack:
|
||||
for i in range(len(result)):
|
||||
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
|
||||
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, presence_penalty, frequency_penalty, repetition_penalty_range)
|
||||
result[i] = RepetitionPenaltyLogitsProcessorWithRange(
|
||||
generation_config.repetition_penalty,
|
||||
generation_config.presence_penalty,
|
||||
generation_config.frequency_penalty,
|
||||
generation_config.repetition_penalty_range
|
||||
)
|
||||
|
||||
if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
|
||||
dry_sequence_breakers = generation_config.dry_sequence_breakers
|
||||
|
||||
# Support both JSON array notation and comma-separated strings.
|
||||
if not dry_sequence_breakers.startswith("["):
|
||||
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
|
||||
|
||||
sequence_breaker_strings = json.loads(dry_sequence_breakers)
|
||||
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
|
||||
sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
|
||||
|
||||
result.append(
|
||||
DRYLogitsProcessor(
|
||||
multiplier=generation_config.dry_multiplier,
|
||||
base=generation_config.dry_base,
|
||||
allowed_length=generation_config.dry_allowed_length,
|
||||
sequence_breakers=sequence_breakers,
|
||||
_range=generation_config.repetition_penalty_range,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@ -483,6 +576,10 @@ def generation_config_init_patch(self, **kwargs):
|
||||
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
|
||||
self.presence_penalty = kwargs.pop("presence_penalty", 0)
|
||||
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
|
||||
self.dry_multiplier = kwargs.pop("dry_multiplier", 0.0)
|
||||
self.dry_base = kwargs.pop("dry_base", 1.75)
|
||||
self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2)
|
||||
self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"')
|
||||
self.temperature_last = kwargs.pop("temperature_last", False)
|
||||
self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat'])
|
||||
|
||||
|
@ -272,7 +272,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
|
||||
|
||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size']:
|
||||
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers']:
|
||||
if k in state:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
|
@ -147,6 +147,10 @@ def list_interface_input_elements():
|
||||
'repetition_penalty_range',
|
||||
'encoder_repetition_penalty',
|
||||
'no_repeat_ngram_size',
|
||||
'dry_multiplier',
|
||||
'dry_base',
|
||||
'dry_allowed_length',
|
||||
'dry_sequence_breakers',
|
||||
'do_sample',
|
||||
'penalty_alpha',
|
||||
'mirostat_mode',
|
||||
|
@ -38,6 +38,14 @@ def create_ui(default_preset):
|
||||
shared.gradio['presence_penalty'] = gr.Slider(0, 2, value=generate_params['presence_penalty'], step=0.05, label='presence_penalty')
|
||||
shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
|
||||
with gr.Blocks():
|
||||
gr.Markdown("[DRY sequence repetition penalty](https://github.com/oobabooga/text-generation-webui/pull/5677)")
|
||||
shared.gradio['dry_multiplier'] = gr.Slider(0, 5, value=generate_params['dry_multiplier'], step=0.01, label='dry_multiplier', info='Set to value > 0 to enable DRY. Controls the magnitude of the penalty for the shortest penalized sequences.')
|
||||
shared.gradio['dry_base'] = gr.Slider(1, 4, value=generate_params['dry_base'], step=0.01, label='dry_base', info='Controls how fast the penalty grows with increasing sequence length.')
|
||||
shared.gradio['dry_allowed_length'] = gr.Slider(1, 20, value=generate_params['dry_allowed_length'], step=1, label='dry_allowed_length', info='Longest sequence that can be repeated without being penalized.')
|
||||
shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')
|
||||
|
||||
gr.Markdown("[Learn more](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab)")
|
||||
|
||||
with gr.Column():
|
||||
|
Loading…
Reference in New Issue
Block a user