mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 01:09:22 +01:00
Fix the sampling monkey patch (and add more options to sampler_priority) (#6411)
This commit is contained in:
parent
c497a32372
commit
7424f789bf
@ -44,7 +44,7 @@ def default_preset():
|
||||
'dry_base': 1.75,
|
||||
'dry_allowed_length': 2,
|
||||
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
|
||||
'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat'
|
||||
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nencoder_repetition_penalty\nno_repeat_ngram'
|
||||
}
|
||||
|
||||
|
||||
|
@ -323,62 +323,141 @@ class SpyLogitsWarper(LogitsWarper):
|
||||
|
||||
|
||||
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||
'''
|
||||
Copied from the transformers library
|
||||
'''
|
||||
|
||||
def __init__(self, penalty: float, presence_penalty: float, frequency_penalty: float, _range: int):
|
||||
def __init__(self, penalty: float, _range: int):
|
||||
if not (penalty > 0):
|
||||
raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")
|
||||
|
||||
self.penalty = penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self._range = _range
|
||||
|
||||
def apply_repetition_penalty(self, input_ids_row, scores_row):
|
||||
unique_ids = torch.unique(input_ids_row)
|
||||
score = torch.gather(scores_row, 0, unique_ids)
|
||||
|
||||
# Apply multiplicative repetition penalty
|
||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||
scores_row.scatter_(0, unique_ids, score)
|
||||
return scores_row
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
input_ids = input_ids[:, -self._range:]
|
||||
|
||||
# We loop here because torch.unique() needs to process each row separately in the
|
||||
# case that batch_size > 1.
|
||||
for input_ids_row, scores_row in zip(input_ids, scores):
|
||||
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
|
||||
score = torch.gather(scores_row, 0, unique_ids)
|
||||
|
||||
# multiplicative repetition penalty
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||
scores_row.scatter_(0, unique_ids, score)
|
||||
|
||||
# presence_penalty and frequency_penalty
|
||||
raw_presence_penalty = (counts > 0).to(scores.dtype)
|
||||
raw_frequency_penalty = counts.to(scores.dtype)
|
||||
additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty
|
||||
scores_row.scatter_add_(0, unique_ids, -additive_penalty)
|
||||
scores_row = self.apply_repetition_penalty(input_ids_row, scores_row)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_warper_patch(self, generation_config, **kwargs):
|
||||
class PresencePenaltyLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self, presence_penalty: float, _range: int):
|
||||
self.presence_penalty = presence_penalty
|
||||
self._range = _range
|
||||
|
||||
def apply_presence_penalty(self, input_ids_row, scores_row):
|
||||
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
|
||||
|
||||
# Apply presence penalty
|
||||
raw_presence_penalty = (counts > 0).to(scores_row.dtype)
|
||||
presence_penalty = raw_presence_penalty * self.presence_penalty
|
||||
scores_row.scatter_add_(0, unique_ids, -presence_penalty)
|
||||
return scores_row
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
input_ids = input_ids[:, -self._range:]
|
||||
for input_ids_row, scores_row in zip(input_ids, scores):
|
||||
scores_row = self.apply_presence_penalty(input_ids_row, scores_row)
|
||||
return scores
|
||||
|
||||
|
||||
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self, frequency_penalty: float, _range: int):
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self._range = _range
|
||||
|
||||
def apply_frequency_penalty(self, input_ids_row, scores_row):
|
||||
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
|
||||
|
||||
# Apply frequency penalty
|
||||
raw_frequency_penalty = counts.to(scores_row.dtype)
|
||||
frequency_penalty = raw_frequency_penalty * self.frequency_penalty
|
||||
scores_row.scatter_add_(0, unique_ids, -frequency_penalty)
|
||||
return scores_row
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
input_ids = input_ids[:, -self._range:]
|
||||
for input_ids_row, scores_row in zip(input_ids, scores):
|
||||
scores_row = self.apply_frequency_penalty(input_ids_row, scores_row)
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_processor_patch(self, **kwargs):
|
||||
generation_config = kwargs['generation_config']
|
||||
|
||||
# Parameter sanitization
|
||||
if isinstance(generation_config.temperature, int):
|
||||
generation_config.temperature = float(generation_config.temperature) # Must be float
|
||||
|
||||
# Get the original warpers
|
||||
warpers = self._get_logits_warper_old(generation_config, **kwargs)
|
||||
warpers = self._get_logits_processor_old(**kwargs)
|
||||
|
||||
for i in range(len(warpers) - 1, -1, -1):
|
||||
# Replace temperature with our modified class.
|
||||
# Currently, it behaves identically to the original.
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||
warpers[i] = TemperatureLogitsWarperCustom(
|
||||
generation_config.temperature,
|
||||
)
|
||||
|
||||
# Stuff we don't need
|
||||
elif warpers[i].__class__.__name__ in ['SuppressTokensLogitsProcessor', 'RepetitionPenaltyLogitsProcessor']:
|
||||
del warpers[i]
|
||||
|
||||
# Add custom warpers
|
||||
warpers_to_add = LogitsProcessorList()
|
||||
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
||||
|
||||
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
|
||||
warpers_to_add.append(
|
||||
RepetitionPenaltyLogitsProcessorWithRange(
|
||||
penalty=generation_config.repetition_penalty,
|
||||
_range=generation_config.repetition_penalty_range
|
||||
)
|
||||
)
|
||||
|
||||
if generation_config.presence_penalty is not None and generation_config.presence_penalty != 0.0:
|
||||
warpers_to_add.append(
|
||||
PresencePenaltyLogitsProcessor(
|
||||
presence_penalty=generation_config.presence_penalty,
|
||||
_range=generation_config.repetition_penalty_range
|
||||
)
|
||||
)
|
||||
|
||||
if generation_config.frequency_penalty is not None and generation_config.frequency_penalty != 0.0:
|
||||
warpers_to_add.append(
|
||||
FrequencyPenaltyLogitsProcessor(
|
||||
frequency_penalty=generation_config.frequency_penalty,
|
||||
_range=generation_config.repetition_penalty_range
|
||||
)
|
||||
)
|
||||
|
||||
if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
|
||||
dry_sequence_breakers = generation_config.dry_sequence_breakers
|
||||
|
||||
# Support both JSON array notation and comma-separated strings.
|
||||
if not dry_sequence_breakers.startswith("["):
|
||||
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
|
||||
|
||||
sequence_breaker_strings = json.loads(dry_sequence_breakers)
|
||||
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
|
||||
sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
|
||||
|
||||
warpers.append(
|
||||
DRYLogitsProcessor(
|
||||
multiplier=generation_config.dry_multiplier,
|
||||
base=generation_config.dry_base,
|
||||
allowed_length=generation_config.dry_allowed_length,
|
||||
sequence_breakers=sequence_breakers,
|
||||
_range=generation_config.repetition_penalty_range,
|
||||
)
|
||||
)
|
||||
|
||||
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
|
||||
warpers_to_add.append(
|
||||
TailFreeLogitsWarper(
|
||||
@ -454,7 +533,12 @@ def get_logits_warper_patch(self, generation_config, **kwargs):
|
||||
'TopALogitsWarper': 'top_a',
|
||||
'TopKLogitsWarper': 'top_k',
|
||||
'TopPLogitsWarper': 'top_p',
|
||||
'TypicalLogitsWarper': 'typical_p'
|
||||
'TypicalLogitsWarper': 'typical_p',
|
||||
'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty',
|
||||
'PresencePenaltyLogitsProcessor': 'presence_penalty',
|
||||
'FrequencyPenaltyLogitsProcessor': 'frequency_penalty',
|
||||
'EncoderRepetitionPenaltyLogitsProcessor': 'encoder_repetition_penalty',
|
||||
'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram',
|
||||
}
|
||||
|
||||
def custom_sort_key(obj):
|
||||
@ -482,49 +566,6 @@ def get_logits_warper_patch(self, generation_config, **kwargs):
|
||||
return warpers
|
||||
|
||||
|
||||
def get_logits_processor_patch(self, **kwargs):
|
||||
generation_config = kwargs['generation_config']
|
||||
|
||||
do_rep_pen_hijack = (generation_config.repetition_penalty > 1) or (generation_config.presence_penalty != 0) or (generation_config.frequency_penalty != 0)
|
||||
if do_rep_pen_hijack:
|
||||
generation_config.repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
|
||||
|
||||
result = self._get_logits_processor_old(**kwargs)
|
||||
|
||||
if do_rep_pen_hijack:
|
||||
for i in range(len(result)):
|
||||
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
|
||||
result[i] = RepetitionPenaltyLogitsProcessorWithRange(
|
||||
generation_config.repetition_penalty,
|
||||
generation_config.presence_penalty,
|
||||
generation_config.frequency_penalty,
|
||||
generation_config.repetition_penalty_range
|
||||
)
|
||||
|
||||
if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
|
||||
dry_sequence_breakers = generation_config.dry_sequence_breakers
|
||||
|
||||
# Support both JSON array notation and comma-separated strings.
|
||||
if not dry_sequence_breakers.startswith("["):
|
||||
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
|
||||
|
||||
sequence_breaker_strings = json.loads(dry_sequence_breakers)
|
||||
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
|
||||
sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
|
||||
|
||||
result.append(
|
||||
DRYLogitsProcessor(
|
||||
multiplier=generation_config.dry_multiplier,
|
||||
base=generation_config.dry_base,
|
||||
allowed_length=generation_config.dry_allowed_length,
|
||||
sequence_breakers=sequence_breakers,
|
||||
_range=generation_config.repetition_penalty_range,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def generation_config_init_patch(self, **kwargs):
|
||||
self.__init___old(**kwargs)
|
||||
self.min_p = kwargs.pop("min_p", 0.0)
|
||||
@ -547,13 +588,10 @@ def generation_config_init_patch(self, **kwargs):
|
||||
self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2)
|
||||
self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"')
|
||||
self.temperature_last = kwargs.pop("temperature_last", False)
|
||||
self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat'])
|
||||
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'encoder_repetition_penalty', 'no_repeat_ngram'])
|
||||
|
||||
|
||||
def hijack_samplers():
|
||||
transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
|
||||
transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch
|
||||
|
||||
transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
|
||||
transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user