mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Fix the sampling monkey patch (and add more options to sampler_priority) (#6411)
This commit is contained in:
parent
c497a32372
commit
7424f789bf
@ -44,7 +44,7 @@ def default_preset():
|
|||||||
'dry_base': 1.75,
|
'dry_base': 1.75,
|
||||||
'dry_allowed_length': 2,
|
'dry_allowed_length': 2,
|
||||||
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
|
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
|
||||||
'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat'
|
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nencoder_repetition_penalty\nno_repeat_ngram'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -323,62 +323,141 @@ class SpyLogitsWarper(LogitsWarper):
|
|||||||
|
|
||||||
|
|
||||||
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||||
'''
|
def __init__(self, penalty: float, _range: int):
|
||||||
Copied from the transformers library
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, penalty: float, presence_penalty: float, frequency_penalty: float, _range: int):
|
|
||||||
if not (penalty > 0):
|
if not (penalty > 0):
|
||||||
raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")
|
raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")
|
||||||
|
|
||||||
self.penalty = penalty
|
self.penalty = penalty
|
||||||
self.presence_penalty = presence_penalty
|
|
||||||
self.frequency_penalty = frequency_penalty
|
|
||||||
self._range = _range
|
self._range = _range
|
||||||
|
|
||||||
|
def apply_repetition_penalty(self, input_ids_row, scores_row):
|
||||||
|
unique_ids = torch.unique(input_ids_row)
|
||||||
|
score = torch.gather(scores_row, 0, unique_ids)
|
||||||
|
|
||||||
|
# Apply multiplicative repetition penalty
|
||||||
|
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||||
|
scores_row.scatter_(0, unique_ids, score)
|
||||||
|
return scores_row
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
input_ids = input_ids[:, -self._range:]
|
input_ids = input_ids[:, -self._range:]
|
||||||
|
|
||||||
# We loop here because torch.unique() needs to process each row separately in the
|
|
||||||
# case that batch_size > 1.
|
|
||||||
for input_ids_row, scores_row in zip(input_ids, scores):
|
for input_ids_row, scores_row in zip(input_ids, scores):
|
||||||
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
|
scores_row = self.apply_repetition_penalty(input_ids_row, scores_row)
|
||||||
score = torch.gather(scores_row, 0, unique_ids)
|
|
||||||
|
|
||||||
# multiplicative repetition penalty
|
|
||||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
|
||||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
|
||||||
scores_row.scatter_(0, unique_ids, score)
|
|
||||||
|
|
||||||
# presence_penalty and frequency_penalty
|
|
||||||
raw_presence_penalty = (counts > 0).to(scores.dtype)
|
|
||||||
raw_frequency_penalty = counts.to(scores.dtype)
|
|
||||||
additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty
|
|
||||||
scores_row.scatter_add_(0, unique_ids, -additive_penalty)
|
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
def get_logits_warper_patch(self, generation_config, **kwargs):
|
class PresencePenaltyLogitsProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, presence_penalty: float, _range: int):
|
||||||
|
self.presence_penalty = presence_penalty
|
||||||
|
self._range = _range
|
||||||
|
|
||||||
|
def apply_presence_penalty(self, input_ids_row, scores_row):
|
||||||
|
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
|
||||||
|
|
||||||
|
# Apply presence penalty
|
||||||
|
raw_presence_penalty = (counts > 0).to(scores_row.dtype)
|
||||||
|
presence_penalty = raw_presence_penalty * self.presence_penalty
|
||||||
|
scores_row.scatter_add_(0, unique_ids, -presence_penalty)
|
||||||
|
return scores_row
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
input_ids = input_ids[:, -self._range:]
|
||||||
|
for input_ids_row, scores_row in zip(input_ids, scores):
|
||||||
|
scores_row = self.apply_presence_penalty(input_ids_row, scores_row)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, frequency_penalty: float, _range: int):
|
||||||
|
self.frequency_penalty = frequency_penalty
|
||||||
|
self._range = _range
|
||||||
|
|
||||||
|
def apply_frequency_penalty(self, input_ids_row, scores_row):
|
||||||
|
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
|
||||||
|
|
||||||
|
# Apply frequency penalty
|
||||||
|
raw_frequency_penalty = counts.to(scores_row.dtype)
|
||||||
|
frequency_penalty = raw_frequency_penalty * self.frequency_penalty
|
||||||
|
scores_row.scatter_add_(0, unique_ids, -frequency_penalty)
|
||||||
|
return scores_row
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
input_ids = input_ids[:, -self._range:]
|
||||||
|
for input_ids_row, scores_row in zip(input_ids, scores):
|
||||||
|
scores_row = self.apply_frequency_penalty(input_ids_row, scores_row)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def get_logits_processor_patch(self, **kwargs):
|
||||||
|
generation_config = kwargs['generation_config']
|
||||||
|
|
||||||
# Parameter sanitization
|
# Parameter sanitization
|
||||||
if isinstance(generation_config.temperature, int):
|
if isinstance(generation_config.temperature, int):
|
||||||
generation_config.temperature = float(generation_config.temperature) # Must be float
|
generation_config.temperature = float(generation_config.temperature) # Must be float
|
||||||
|
|
||||||
# Get the original warpers
|
# Get the original warpers
|
||||||
warpers = self._get_logits_warper_old(generation_config, **kwargs)
|
warpers = self._get_logits_processor_old(**kwargs)
|
||||||
|
|
||||||
# Replace temperature with our modified class.
|
for i in range(len(warpers) - 1, -1, -1):
|
||||||
# Currently, it behaves identically to the original.
|
# Replace temperature with our modified class.
|
||||||
for i in range(len(warpers)):
|
|
||||||
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||||
warpers[i] = TemperatureLogitsWarperCustom(
|
warpers[i] = TemperatureLogitsWarperCustom(
|
||||||
generation_config.temperature,
|
generation_config.temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Stuff we don't need
|
||||||
|
elif warpers[i].__class__.__name__ in ['SuppressTokensLogitsProcessor', 'RepetitionPenaltyLogitsProcessor']:
|
||||||
|
del warpers[i]
|
||||||
|
|
||||||
# Add custom warpers
|
# Add custom warpers
|
||||||
warpers_to_add = LogitsProcessorList()
|
warpers_to_add = LogitsProcessorList()
|
||||||
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
||||||
|
|
||||||
|
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
|
||||||
|
warpers_to_add.append(
|
||||||
|
RepetitionPenaltyLogitsProcessorWithRange(
|
||||||
|
penalty=generation_config.repetition_penalty,
|
||||||
|
_range=generation_config.repetition_penalty_range
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if generation_config.presence_penalty is not None and generation_config.presence_penalty != 0.0:
|
||||||
|
warpers_to_add.append(
|
||||||
|
PresencePenaltyLogitsProcessor(
|
||||||
|
presence_penalty=generation_config.presence_penalty,
|
||||||
|
_range=generation_config.repetition_penalty_range
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if generation_config.frequency_penalty is not None and generation_config.frequency_penalty != 0.0:
|
||||||
|
warpers_to_add.append(
|
||||||
|
FrequencyPenaltyLogitsProcessor(
|
||||||
|
frequency_penalty=generation_config.frequency_penalty,
|
||||||
|
_range=generation_config.repetition_penalty_range
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
|
||||||
|
dry_sequence_breakers = generation_config.dry_sequence_breakers
|
||||||
|
|
||||||
|
# Support both JSON array notation and comma-separated strings.
|
||||||
|
if not dry_sequence_breakers.startswith("["):
|
||||||
|
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
|
||||||
|
|
||||||
|
sequence_breaker_strings = json.loads(dry_sequence_breakers)
|
||||||
|
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
|
||||||
|
sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
|
||||||
|
|
||||||
|
warpers.append(
|
||||||
|
DRYLogitsProcessor(
|
||||||
|
multiplier=generation_config.dry_multiplier,
|
||||||
|
base=generation_config.dry_base,
|
||||||
|
allowed_length=generation_config.dry_allowed_length,
|
||||||
|
sequence_breakers=sequence_breakers,
|
||||||
|
_range=generation_config.repetition_penalty_range,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
|
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
|
||||||
warpers_to_add.append(
|
warpers_to_add.append(
|
||||||
TailFreeLogitsWarper(
|
TailFreeLogitsWarper(
|
||||||
@ -454,7 +533,12 @@ def get_logits_warper_patch(self, generation_config, **kwargs):
|
|||||||
'TopALogitsWarper': 'top_a',
|
'TopALogitsWarper': 'top_a',
|
||||||
'TopKLogitsWarper': 'top_k',
|
'TopKLogitsWarper': 'top_k',
|
||||||
'TopPLogitsWarper': 'top_p',
|
'TopPLogitsWarper': 'top_p',
|
||||||
'TypicalLogitsWarper': 'typical_p'
|
'TypicalLogitsWarper': 'typical_p',
|
||||||
|
'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty',
|
||||||
|
'PresencePenaltyLogitsProcessor': 'presence_penalty',
|
||||||
|
'FrequencyPenaltyLogitsProcessor': 'frequency_penalty',
|
||||||
|
'EncoderRepetitionPenaltyLogitsProcessor': 'encoder_repetition_penalty',
|
||||||
|
'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram',
|
||||||
}
|
}
|
||||||
|
|
||||||
def custom_sort_key(obj):
|
def custom_sort_key(obj):
|
||||||
@ -482,49 +566,6 @@ def get_logits_warper_patch(self, generation_config, **kwargs):
|
|||||||
return warpers
|
return warpers
|
||||||
|
|
||||||
|
|
||||||
def get_logits_processor_patch(self, **kwargs):
|
|
||||||
generation_config = kwargs['generation_config']
|
|
||||||
|
|
||||||
do_rep_pen_hijack = (generation_config.repetition_penalty > 1) or (generation_config.presence_penalty != 0) or (generation_config.frequency_penalty != 0)
|
|
||||||
if do_rep_pen_hijack:
|
|
||||||
generation_config.repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
|
|
||||||
|
|
||||||
result = self._get_logits_processor_old(**kwargs)
|
|
||||||
|
|
||||||
if do_rep_pen_hijack:
|
|
||||||
for i in range(len(result)):
|
|
||||||
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
|
|
||||||
result[i] = RepetitionPenaltyLogitsProcessorWithRange(
|
|
||||||
generation_config.repetition_penalty,
|
|
||||||
generation_config.presence_penalty,
|
|
||||||
generation_config.frequency_penalty,
|
|
||||||
generation_config.repetition_penalty_range
|
|
||||||
)
|
|
||||||
|
|
||||||
if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
|
|
||||||
dry_sequence_breakers = generation_config.dry_sequence_breakers
|
|
||||||
|
|
||||||
# Support both JSON array notation and comma-separated strings.
|
|
||||||
if not dry_sequence_breakers.startswith("["):
|
|
||||||
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
|
|
||||||
|
|
||||||
sequence_breaker_strings = json.loads(dry_sequence_breakers)
|
|
||||||
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
|
|
||||||
sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
|
|
||||||
|
|
||||||
result.append(
|
|
||||||
DRYLogitsProcessor(
|
|
||||||
multiplier=generation_config.dry_multiplier,
|
|
||||||
base=generation_config.dry_base,
|
|
||||||
allowed_length=generation_config.dry_allowed_length,
|
|
||||||
sequence_breakers=sequence_breakers,
|
|
||||||
_range=generation_config.repetition_penalty_range,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def generation_config_init_patch(self, **kwargs):
|
def generation_config_init_patch(self, **kwargs):
|
||||||
self.__init___old(**kwargs)
|
self.__init___old(**kwargs)
|
||||||
self.min_p = kwargs.pop("min_p", 0.0)
|
self.min_p = kwargs.pop("min_p", 0.0)
|
||||||
@ -547,13 +588,10 @@ def generation_config_init_patch(self, **kwargs):
|
|||||||
self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2)
|
self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2)
|
||||||
self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"')
|
self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"')
|
||||||
self.temperature_last = kwargs.pop("temperature_last", False)
|
self.temperature_last = kwargs.pop("temperature_last", False)
|
||||||
self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat'])
|
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'encoder_repetition_penalty', 'no_repeat_ngram'])
|
||||||
|
|
||||||
|
|
||||||
def hijack_samplers():
|
def hijack_samplers():
|
||||||
transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
|
|
||||||
transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch
|
|
||||||
|
|
||||||
transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
|
transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
|
||||||
transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch
|
transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user