Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-11-21 23:57:58 +01:00
Exclude Top Choices (XTC): A sampler that boosts creativity, breaks writing clichés, and inhibits non-verbatim repetition (#6335)
This commit is contained in:
  parent 3492e33fd5
  commit 301375834e
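In brief, XTC operates on the sorted next-token distribution: whenever two or more tokens have probability above a threshold, it removes all of them except the least probable one, and it only acts on a given step with a configured probability. Below is a minimal toy sketch of that idea for a single unbatched distribution (illustrative only; the name `xtc_toy` is hypothetical, and the actual implementation is the `XTCLogitsWarper` class added later in this diff):

    import torch

    def xtc_toy(logits: torch.Tensor, threshold: float, probability: float) -> torch.Tensor:
        # Only act with the configured probability
        if torch.rand(()).item() >= probability:
            return logits

        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Mark sorted position i for removal iff position i + 1 still meets the
        # threshold, i.e. every qualifying token except the least probable one
        remove_sorted = torch.zeros_like(probs, dtype=torch.bool)
        remove_sorted[:-1] = probs[1:] >= threshold

        # Route the mask back from sorted order to vocabulary order
        remove = torch.zeros_like(remove_sorted)
        remove[sorted_indices] = remove_sorted
        return logits.masked_fill(remove, -float("inf"))

    # With probs (0.4, 0.3, 0.2, 0.1) and threshold 0.15, the 0.4 and 0.3
    # tokens are excluded and the 0.2 token survives as the last qualifier
    logits = torch.tensor([0.4, 0.3, 0.2, 0.1]).log()
    print(xtc_toy(logits, threshold=0.15, probability=1.0))

By always sparing the least probable token over the threshold, the sampler steers generation away from the most predictable continuations without ever leaving only junk tokens behind.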
@@ -37,6 +37,8 @@ class GenerationOptions(BaseModel):
     dry_base: float = 1.75
     dry_allowed_length: int = 2
     dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
+    xtc_threshold: float = 0.1
+    xtc_probability: float = 0
     truncation_length: int = 0
     max_tokens_second: int = 0
     prompt_lookup_num_tokens: int = 0
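Because the two fields above land on `GenerationOptions`, XTC can be driven per-request through the project's OpenAI-compatible API like any other sampler parameter. A hedged sketch (assumes the API extension is running on its default local port; adjust the URL for your setup):

    import requests

    response = requests.post(
        "http://127.0.0.1:5000/v1/completions",
        json={
            "prompt": "Once upon a time",
            "max_tokens": 200,
            "xtc_threshold": 0.1,    # default from the diff above
            "xtc_probability": 0.5,  # 0 would disable the sampler
        },
    )
    print(response.json()["choices"][0]["text"])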
@@ -168,6 +168,8 @@ def transformers_samplers():
         'dry_base',
         'dry_allowed_length',
         'dry_sequence_breakers',
+        'xtc_threshold',
+        'xtc_probability',
         'seed',
         'do_sample',
         'penalty_alpha',
@@ -242,6 +244,8 @@ loaders_samplers = {
         'dry_base',
         'dry_allowed_length',
         'dry_sequence_breakers',
+        'xtc_threshold',
+        'xtc_probability',
         'seed',
         'do_sample',
         'mirostat_mode',
@@ -304,6 +308,8 @@ loaders_samplers = {
         'dry_base',
         'dry_allowed_length',
         'dry_sequence_breakers',
+        'xtc_threshold',
+        'xtc_probability',
         'seed',
         'do_sample',
         'mirostat_mode',
@@ -44,7 +44,9 @@ def default_preset():
         'dry_base': 1.75,
         'dry_allowed_length': 2,
         'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
-        'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nencoder_repetition_penalty\nno_repeat_ngram'
+        'xtc_threshold': 0.1,
+        'xtc_probability': 0,
+        'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram'
     }
@@ -1,6 +1,7 @@
 import json
 import math
 import pprint
+import random

 import torch
 import transformers
@@ -191,6 +192,53 @@ class TopALogitsWarper(LogitsWarper):
         return scores


+# Exclude Top Choices (XTC)
+class XTCLogitsWarper(LogitsWarper):
+    def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
+        self.threshold = threshold
+        self.probability = probability
+        self.filter_value = filter_value
+        self.special_token_ids = [
+            shared.tokenizer.encode("\n")[-1],
+        ]
+
+        if shared.tokenizer.eos_token_id is not None:
+            self.special_token_ids.append(shared.tokenizer.eos_token_id)
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # `random` returns values in the half-open range [0, 1), so setting `probability`
+        # to 0 means the sampler never takes action, while setting it to 1 means the sampler
+        # always takes action.
+        #
+        # Note that while XTC is most intuitively described as "if multiple tokens meet
+        # the threshold, then with probability...", reversing the two conditions is logically
+        # equivalent, and improves performance because processing can immediately be stopped
+        # if the random check fails.
+        if random.random() >= self.probability:
+            return scores
+
+        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
+        probs = sorted_logits.softmax(dim=-1)
+
+        sorted_indices_to_remove = torch.full_like(probs, False, dtype=torch.bool)
+
+        # This operation sets exactly those indices to `True` for which the next index has
+        # probability above the threshold. Since `probs` is sorted, those are the indices
+        # of all tokens that meet the threshold, *except* the least probable one.
+        sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold
+
+        # Convert sorted_indices_to_remove to the original indices
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+        # If newline or EOS tokens would be removed, return the original scores
+        if indices_to_remove[:, self.special_token_ids].any():
+            return scores
+
+        # Otherwise, remove tokens with the mask
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
+
+
 class DRYLogitsProcessor(LogitsProcessor):
     def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int):
         self.multiplier = multiplier
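The shifted comparison in `__call__` above is the core of the sampler and is worth tracing on concrete numbers. A small illustrative check (toy values, batch of one), mirroring the committed tensor operations:

    import torch

    # Probabilities already sorted in descending order; threshold = 0.1
    probs = torch.tensor([[0.50, 0.25, 0.15, 0.06, 0.04]])
    sorted_indices_to_remove = torch.full_like(probs, False, dtype=torch.bool)

    # Position i is marked iff position i + 1 still meets the threshold, so the
    # least probable qualifying token (0.15) is always spared
    sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= 0.1
    print(sorted_indices_to_remove)
    # tensor([[ True,  True, False, False, False]])

The subsequent `scatter` call then routes this sorted-order mask back to vocabulary order, and the newline/EOS check gives those structural tokens a veto over the whole operation.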
@@ -474,6 +522,14 @@ def get_logits_processor_patch(self, **kwargs):
                 )
             )

+        if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
+            warpers_to_add.append(
+                XTCLogitsWarper(
+                    threshold=generation_config.xtc_threshold,
+                    probability=generation_config.xtc_probability,
+                )
+            )
+
         if generation_config.dynamic_temperature:
             warpers_to_add.append(
                 DynamicTemperatureLogitsWarper(
@@ -534,6 +590,7 @@ def get_logits_processor_patch(self, **kwargs):
         'TopKLogitsWarper': 'top_k',
         'TopPLogitsWarper': 'top_p',
         'TypicalLogitsWarper': 'typical_p',
+        'XTCLogitsWarper': 'xtc',
         'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty',
         'PresencePenaltyLogitsProcessor': 'presence_penalty',
         'FrequencyPenaltyLogitsProcessor': 'frequency_penalty',
@@ -588,8 +645,10 @@ def generation_config_init_patch(self, **kwargs):
         self.dry_base = kwargs.pop("dry_base", 1.75)
         self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2)
         self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"')
+        self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
+        self.xtc_probability = kwargs.pop("xtc_probability", 0)
         self.temperature_last = kwargs.pop("temperature_last", False)
-        self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'encoder_repetition_penalty', 'no_repeat_ngram'])
+        self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])


 def hijack_samplers():
@@ -289,7 +289,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):


 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers']:
+    for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']:
         if k in state:
             generate_params[k] = state[k]
@@ -158,6 +158,8 @@ def list_interface_input_elements():
         'dry_base',
         'dry_allowed_length',
         'dry_sequence_breakers',
+        'xtc_threshold',
+        'xtc_probability',
         'do_sample',
         'penalty_alpha',
         'mirostat_mode',
@@ -45,6 +45,10 @@ def create_ui(default_preset):
             shared.gradio['dry_base'] = gr.Slider(1, 4, value=generate_params['dry_base'], step=0.01, label='dry_base', info='Controls how fast the penalty grows with increasing sequence length.')
             shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')

+            with gr.Blocks():
+                shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=generate_params['xtc_threshold'], step=0.01, label='xtc_threshold', info='If 2 or more tokens have probability above this threshold, consider removing all but the last one.')
+                shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=generate_params['xtc_probability'], step=0.01, label='xtc_probability', info='Probability that the removal will actually happen. 0 disables the sampler. 1 makes it always happen.')
+
             gr.Markdown("[Learn more](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab)")

         with gr.Column():