From 7424f789bf699c9b2efc874da2b84168bb9e7a6f Mon Sep 17 00:00:00 2001 From: oobabooga Date: Fri, 27 Sep 2024 19:03:25 -0300 Subject: [PATCH] Fix the sampling monkey patch (and add more options to sampler_priority) (#6411) --- modules/presets.py | 2 +- modules/sampler_hijack.py | 192 +++++++++++++++++++++++--------------- 2 files changed, 116 insertions(+), 78 deletions(-) diff --git a/modules/presets.py b/modules/presets.py index b00e829e..14c5a777 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -44,7 +44,7 @@ def default_preset(): 'dry_base': 1.75, 'dry_allowed_length': 2, 'dry_sequence_breakers': '"\\n", ":", "\\"", "*"', - 'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat' + 'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nencoder_repetition_penalty\nno_repeat_ngram' } diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 9fb661ae..9e865bb3 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -323,62 +323,141 @@ class SpyLogitsWarper(LogitsWarper): class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor): - ''' - Copied from the transformers library - ''' - - def __init__(self, penalty: float, presence_penalty: float, frequency_penalty: float, _range: int): + def __init__(self, penalty: float, _range: int): if not (penalty > 0): raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}") - self.penalty = penalty - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty self._range = _range + def apply_repetition_penalty(self, input_ids_row, scores_row): + unique_ids = torch.unique(input_ids_row) + score = torch.gather(scores_row, 0, unique_ids) + + # Apply multiplicative repetition penalty + score = torch.where(score < 0, score * self.penalty, score / self.penalty) + scores_row.scatter_(0, unique_ids, score) + return scores_row + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: input_ids = input_ids[:, -self._range:] - - # We loop here because torch.unique() needs to process each row separately in the - # case that batch_size > 1. for input_ids_row, scores_row in zip(input_ids, scores): - unique_ids, counts = torch.unique(input_ids_row, return_counts=True) - score = torch.gather(scores_row, 0, unique_ids) - - # multiplicative repetition penalty - # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability - score = torch.where(score < 0, score * self.penalty, score / self.penalty) - scores_row.scatter_(0, unique_ids, score) - - # presence_penalty and frequency_penalty - raw_presence_penalty = (counts > 0).to(scores.dtype) - raw_frequency_penalty = counts.to(scores.dtype) - additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty - scores_row.scatter_add_(0, unique_ids, -additive_penalty) + scores_row = self.apply_repetition_penalty(input_ids_row, scores_row) return scores -def get_logits_warper_patch(self, generation_config, **kwargs): +class PresencePenaltyLogitsProcessor(LogitsProcessor): + def __init__(self, presence_penalty: float, _range: int): + self.presence_penalty = presence_penalty + self._range = _range + + def apply_presence_penalty(self, input_ids_row, scores_row): + unique_ids, counts = torch.unique(input_ids_row, return_counts=True) + + # Apply presence penalty + raw_presence_penalty = (counts > 0).to(scores_row.dtype) + presence_penalty = raw_presence_penalty * self.presence_penalty + scores_row.scatter_add_(0, unique_ids, -presence_penalty) + return scores_row + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + input_ids = input_ids[:, -self._range:] + for input_ids_row, scores_row in zip(input_ids, scores): + scores_row = self.apply_presence_penalty(input_ids_row, scores_row) + return scores + + +class FrequencyPenaltyLogitsProcessor(LogitsProcessor): + def __init__(self, frequency_penalty: float, _range: int): + self.frequency_penalty = frequency_penalty + self._range = _range + + def apply_frequency_penalty(self, input_ids_row, scores_row): + unique_ids, counts = torch.unique(input_ids_row, return_counts=True) + + # Apply frequency penalty + raw_frequency_penalty = counts.to(scores_row.dtype) + frequency_penalty = raw_frequency_penalty * self.frequency_penalty + scores_row.scatter_add_(0, unique_ids, -frequency_penalty) + return scores_row + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + input_ids = input_ids[:, -self._range:] + for input_ids_row, scores_row in zip(input_ids, scores): + scores_row = self.apply_frequency_penalty(input_ids_row, scores_row) + return scores + + +def get_logits_processor_patch(self, **kwargs): + generation_config = kwargs['generation_config'] # Parameter sanitization if isinstance(generation_config.temperature, int): generation_config.temperature = float(generation_config.temperature) # Must be float # Get the original warpers - warpers = self._get_logits_warper_old(generation_config, **kwargs) + warpers = self._get_logits_processor_old(**kwargs) - # Replace temperature with our modified class. - # Currently, it behaves identically to the original. - for i in range(len(warpers)): + for i in range(len(warpers) - 1, -1, -1): + # Replace temperature with our modified class. if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper': warpers[i] = TemperatureLogitsWarperCustom( generation_config.temperature, ) + # Stuff we don't need + elif warpers[i].__class__.__name__ in ['SuppressTokensLogitsProcessor', 'RepetitionPenaltyLogitsProcessor']: + del warpers[i] + # Add custom warpers warpers_to_add = LogitsProcessorList() min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1 + + if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: + warpers_to_add.append( + RepetitionPenaltyLogitsProcessorWithRange( + penalty=generation_config.repetition_penalty, + _range=generation_config.repetition_penalty_range + ) + ) + + if generation_config.presence_penalty is not None and generation_config.presence_penalty != 0.0: + warpers_to_add.append( + PresencePenaltyLogitsProcessor( + presence_penalty=generation_config.presence_penalty, + _range=generation_config.repetition_penalty_range + ) + ) + + if generation_config.frequency_penalty is not None and generation_config.frequency_penalty != 0.0: + warpers_to_add.append( + FrequencyPenaltyLogitsProcessor( + frequency_penalty=generation_config.frequency_penalty, + _range=generation_config.repetition_penalty_range + ) + ) + + if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0: + dry_sequence_breakers = generation_config.dry_sequence_breakers + + # Support both JSON array notation and comma-separated strings. + if not dry_sequence_breakers.startswith("["): + dry_sequence_breakers = "[" + dry_sequence_breakers + "]" + + sequence_breaker_strings = json.loads(dry_sequence_breakers) + # Prefix with 'a' to get the correct encoding of the token at the end of a text. + sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings} + + warpers.append( + DRYLogitsProcessor( + multiplier=generation_config.dry_multiplier, + base=generation_config.dry_base, + allowed_length=generation_config.dry_allowed_length, + sequence_breakers=sequence_breakers, + _range=generation_config.repetition_penalty_range, + ) + ) + if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0: warpers_to_add.append( TailFreeLogitsWarper( @@ -454,7 +533,12 @@ def get_logits_warper_patch(self, generation_config, **kwargs): 'TopALogitsWarper': 'top_a', 'TopKLogitsWarper': 'top_k', 'TopPLogitsWarper': 'top_p', - 'TypicalLogitsWarper': 'typical_p' + 'TypicalLogitsWarper': 'typical_p', + 'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty', + 'PresencePenaltyLogitsProcessor': 'presence_penalty', + 'FrequencyPenaltyLogitsProcessor': 'frequency_penalty', + 'EncoderRepetitionPenaltyLogitsProcessor': 'encoder_repetition_penalty', + 'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram', } def custom_sort_key(obj): @@ -482,49 +566,6 @@ def get_logits_warper_patch(self, generation_config, **kwargs): return warpers -def get_logits_processor_patch(self, **kwargs): - generation_config = kwargs['generation_config'] - - do_rep_pen_hijack = (generation_config.repetition_penalty > 1) or (generation_config.presence_penalty != 0) or (generation_config.frequency_penalty != 0) - if do_rep_pen_hijack: - generation_config.repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created - - result = self._get_logits_processor_old(**kwargs) - - if do_rep_pen_hijack: - for i in range(len(result)): - if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor': - result[i] = RepetitionPenaltyLogitsProcessorWithRange( - generation_config.repetition_penalty, - generation_config.presence_penalty, - generation_config.frequency_penalty, - generation_config.repetition_penalty_range - ) - - if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0: - dry_sequence_breakers = generation_config.dry_sequence_breakers - - # Support both JSON array notation and comma-separated strings. - if not dry_sequence_breakers.startswith("["): - dry_sequence_breakers = "[" + dry_sequence_breakers + "]" - - sequence_breaker_strings = json.loads(dry_sequence_breakers) - # Prefix with 'a' to get the correct encoding of the token at the end of a text. - sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings} - - result.append( - DRYLogitsProcessor( - multiplier=generation_config.dry_multiplier, - base=generation_config.dry_base, - allowed_length=generation_config.dry_allowed_length, - sequence_breakers=sequence_breakers, - _range=generation_config.repetition_penalty_range, - ) - ) - - return result - - def generation_config_init_patch(self, **kwargs): self.__init___old(**kwargs) self.min_p = kwargs.pop("min_p", 0.0) @@ -547,13 +588,10 @@ def generation_config_init_patch(self, **kwargs): self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2) self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"') self.temperature_last = kwargs.pop("temperature_last", False) - self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat']) + self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'encoder_repetition_penalty', 'no_repeat_ngram']) def hijack_samplers(): - transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper - transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch - transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch