Fix the sampling monkey patch (and add more options to sampler_priority) (#6411)

oobabooga 2024-09-27 19:03:25 -03:00 committed by GitHub
parent c497a32372
commit 7424f789bf
2 changed files with 116 additions and 78 deletions


@@ -44,7 +44,7 @@ def default_preset():
         'dry_base': 1.75,
         'dry_allowed_length': 2,
         'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
-        'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat'
+        'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nencoder_repetition_penalty\nno_repeat_ngram'
     }

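Note: the preset above stores sampler_priority as a newline-separated string, while generation_config_init_patch further down in this diff defaults to the equivalent Python list. A minimal sketch of how the string form maps onto the list form; the preset_value name and the split()-based parsing are illustrative assumptions, not the repository's own parsing code.

# Illustrative only: turning the newline-separated preset value into the list
# form used as the GenerationConfig default later in this diff.
preset_value = (
    'repetition_penalty\npresence_penalty\nfrequency_penalty\ntemperature\n'
    'dynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\n'
    'epsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\n'
    'encoder_repetition_penalty\nno_repeat_ngram'
)
sampler_priority = [name.strip() for name in preset_value.split('\n') if name.strip()]
assert sampler_priority[0] == 'repetition_penalty'   # penalties now come first by default
assert sampler_priority[-1] == 'no_repeat_ngram'     # and no_repeat_ngram runs last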

@@ -323,62 +323,141 @@ class SpyLogitsWarper(LogitsWarper):
 class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
     '''
     Copied from the transformers library
     '''
-    def __init__(self, penalty: float, presence_penalty: float, frequency_penalty: float, _range: int):
+    def __init__(self, penalty: float, _range: int):
         if not (penalty > 0):
             raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")
         self.penalty = penalty
-        self.presence_penalty = presence_penalty
-        self.frequency_penalty = frequency_penalty
         self._range = _range
+    def apply_repetition_penalty(self, input_ids_row, scores_row):
+        unique_ids = torch.unique(input_ids_row)
+        score = torch.gather(scores_row, 0, unique_ids)
+        # Apply multiplicative repetition penalty
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        scores_row.scatter_(0, unique_ids, score)
+        return scores_row
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         input_ids = input_ids[:, -self._range:]
-        # We loop here because torch.unique() needs to process each row separately in the
-        # case that batch_size > 1.
         for input_ids_row, scores_row in zip(input_ids, scores):
-            unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
-            score = torch.gather(scores_row, 0, unique_ids)
-            # multiplicative repetition penalty
-            # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-            score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-            scores_row.scatter_(0, unique_ids, score)
-            # presence_penalty and frequency_penalty
-            raw_presence_penalty = (counts > 0).to(scores.dtype)
-            raw_frequency_penalty = counts.to(scores.dtype)
-            additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty
-            scores_row.scatter_add_(0, unique_ids, -additive_penalty)
+            scores_row = self.apply_repetition_penalty(input_ids_row, scores_row)
         return scores
-def get_logits_warper_patch(self, generation_config, **kwargs):
+class PresencePenaltyLogitsProcessor(LogitsProcessor):
+    def __init__(self, presence_penalty: float, _range: int):
+        self.presence_penalty = presence_penalty
+        self._range = _range
+    def apply_presence_penalty(self, input_ids_row, scores_row):
+        unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
+        # Apply presence penalty
+        raw_presence_penalty = (counts > 0).to(scores_row.dtype)
+        presence_penalty = raw_presence_penalty * self.presence_penalty
+        scores_row.scatter_add_(0, unique_ids, -presence_penalty)
+        return scores_row
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        input_ids = input_ids[:, -self._range:]
+        for input_ids_row, scores_row in zip(input_ids, scores):
+            scores_row = self.apply_presence_penalty(input_ids_row, scores_row)
+        return scores
+class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
+    def __init__(self, frequency_penalty: float, _range: int):
+        self.frequency_penalty = frequency_penalty
+        self._range = _range
+    def apply_frequency_penalty(self, input_ids_row, scores_row):
+        unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
+        # Apply frequency penalty
+        raw_frequency_penalty = counts.to(scores_row.dtype)
+        frequency_penalty = raw_frequency_penalty * self.frequency_penalty
+        scores_row.scatter_add_(0, unique_ids, -frequency_penalty)
+        return scores_row
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        input_ids = input_ids[:, -self._range:]
+        for input_ids_row, scores_row in zip(input_ids, scores):
+            scores_row = self.apply_frequency_penalty(input_ids_row, scores_row)
+        return scores
+def get_logits_processor_patch(self, **kwargs):
+    generation_config = kwargs['generation_config']
     # Parameter sanitization
     if isinstance(generation_config.temperature, int):
         generation_config.temperature = float(generation_config.temperature) # Must be float
     # Get the original warpers
-    warpers = self._get_logits_warper_old(generation_config, **kwargs)
+    warpers = self._get_logits_processor_old(**kwargs)
-    # Replace temperature with our modified class.
-    # Currently, it behaves identically to the original.
-    for i in range(len(warpers)):
+    for i in range(len(warpers) - 1, -1, -1):
+        # Replace temperature with our modified class.
         if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
             warpers[i] = TemperatureLogitsWarperCustom(
                 generation_config.temperature,
             )
+        # Stuff we don't need
+        elif warpers[i].__class__.__name__ in ['SuppressTokensLogitsProcessor', 'RepetitionPenaltyLogitsProcessor']:
+            del warpers[i]
     # Add custom warpers
     warpers_to_add = LogitsProcessorList()
     min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
+    if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
+        warpers_to_add.append(
+            RepetitionPenaltyLogitsProcessorWithRange(
+                penalty=generation_config.repetition_penalty,
+                _range=generation_config.repetition_penalty_range
+            )
+        )
+    if generation_config.presence_penalty is not None and generation_config.presence_penalty != 0.0:
+        warpers_to_add.append(
+            PresencePenaltyLogitsProcessor(
+                presence_penalty=generation_config.presence_penalty,
+                _range=generation_config.repetition_penalty_range
+            )
+        )
+    if generation_config.frequency_penalty is not None and generation_config.frequency_penalty != 0.0:
+        warpers_to_add.append(
+            FrequencyPenaltyLogitsProcessor(
+                frequency_penalty=generation_config.frequency_penalty,
+                _range=generation_config.repetition_penalty_range
+            )
+        )
+    if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
+        dry_sequence_breakers = generation_config.dry_sequence_breakers
+        # Support both JSON array notation and comma-separated strings.
+        if not dry_sequence_breakers.startswith("["):
+            dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
+        sequence_breaker_strings = json.loads(dry_sequence_breakers)
+        # Prefix with 'a' to get the correct encoding of the token at the end of a text.
+        sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
+        warpers.append(
+            DRYLogitsProcessor(
+                multiplier=generation_config.dry_multiplier,
+                base=generation_config.dry_base,
+                allowed_length=generation_config.dry_allowed_length,
+                sequence_breakers=sequence_breakers,
+                _range=generation_config.repetition_penalty_range,
+            )
+        )
     if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
         warpers_to_add.append(
             TailFreeLogitsWarper(
@@ -454,7 +533,12 @@ def get_logits_warper_patch(self, generation_config, **kwargs):
         'TopALogitsWarper': 'top_a',
         'TopKLogitsWarper': 'top_k',
         'TopPLogitsWarper': 'top_p',
-        'TypicalLogitsWarper': 'typical_p'
+        'TypicalLogitsWarper': 'typical_p',
+        'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty',
+        'PresencePenaltyLogitsProcessor': 'presence_penalty',
+        'FrequencyPenaltyLogitsProcessor': 'frequency_penalty',
+        'EncoderRepetitionPenaltyLogitsProcessor': 'encoder_repetition_penalty',
+        'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram',
     }
     def custom_sort_key(obj):
@@ -482,49 +566,6 @@ def get_logits_warper_patch(self, generation_config, **kwargs):
     return warpers
-def get_logits_processor_patch(self, **kwargs):
-    generation_config = kwargs['generation_config']
-    do_rep_pen_hijack = (generation_config.repetition_penalty > 1) or (generation_config.presence_penalty != 0) or (generation_config.frequency_penalty != 0)
-    if do_rep_pen_hijack:
-        generation_config.repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
-    result = self._get_logits_processor_old(**kwargs)
-    if do_rep_pen_hijack:
-        for i in range(len(result)):
-            if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
-                result[i] = RepetitionPenaltyLogitsProcessorWithRange(
-                    generation_config.repetition_penalty,
-                    generation_config.presence_penalty,
-                    generation_config.frequency_penalty,
-                    generation_config.repetition_penalty_range
-                )
-    if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
-        dry_sequence_breakers = generation_config.dry_sequence_breakers
-        # Support both JSON array notation and comma-separated strings.
-        if not dry_sequence_breakers.startswith("["):
-            dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
-        sequence_breaker_strings = json.loads(dry_sequence_breakers)
-        # Prefix with 'a' to get the correct encoding of the token at the end of a text.
-        sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
-        result.append(
-            DRYLogitsProcessor(
-                multiplier=generation_config.dry_multiplier,
-                base=generation_config.dry_base,
-                allowed_length=generation_config.dry_allowed_length,
-                sequence_breakers=sequence_breakers,
-                _range=generation_config.repetition_penalty_range,
-            )
-        )
-    return result
 def generation_config_init_patch(self, **kwargs):
     self.__init___old(**kwargs)
     self.min_p = kwargs.pop("min_p", 0.0)
@@ -547,13 +588,10 @@ def generation_config_init_patch(self, **kwargs):
     self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2)
     self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"')
     self.temperature_last = kwargs.pop("temperature_last", False)
-    self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat'])
+    self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'encoder_repetition_penalty', 'no_repeat_ngram'])
 def hijack_samplers():
-    transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
-    transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch
     transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
     transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch
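How the pieces fit together: hijack_samplers() swaps transformers' _get_logits_processor for the patched version, and inside that patch the assembled processors are re-sorted so that samplers named earlier in sampler_priority run first, which is why the penalty processors now get their own nicknames in the table above. Below is a rough sketch of that reordering idea, using plain strings in place of processor objects and a hand-written priority_of helper; it is an illustration under those assumptions, not the repository's custom_sort_key, and the nickname table is only loosely based on the one extended in this commit.

# Illustrative sketch only. Strings stand in for the LogitsProcessor objects
# that get_logits_processor_patch actually sorts.
sampler_priority = [
    'repetition_penalty', 'presence_penalty', 'frequency_penalty',
    'temperature', 'top_k', 'top_p', 'no_repeat_ngram',
]

# Loosely based on the class-name-to-nickname table extended in this commit.
class_name_to_nickname = {
    'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty',
    'PresencePenaltyLogitsProcessor': 'presence_penalty',
    'FrequencyPenaltyLogitsProcessor': 'frequency_penalty',
    'TemperatureLogitsWarperCustom': 'temperature',
    'TopKLogitsWarper': 'top_k',
    'TopPLogitsWarper': 'top_p',
    'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram',
}

def priority_of(class_name):
    # Samplers without a configured priority keep a stable slot at the end.
    nickname = class_name_to_nickname.get(class_name)
    return sampler_priority.index(nickname) if nickname in sampler_priority else len(sampler_priority)

processors = [
    'TopPLogitsWarper',
    'TemperatureLogitsWarperCustom',
    'NoRepeatNGramLogitsProcessor',
    'RepetitionPenaltyLogitsProcessorWithRange',
]
processors.sort(key=priority_of)
# With the new default order, the repetition penalty is applied first and
# no_repeat_ngram last, instead of being locked to transformers' built-in order.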