Quadratic sampling (#5403)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
kalomaze 2024-02-03 21:20:02 -06:00 committed by GitHub
parent e98d1086f5
commit b6077b02e4
8 changed files with 45 additions and 15 deletions


@@ -55,6 +55,7 @@ For more information about the parameters, the [transformers documentation](http
* **mirostat_tau**: No idea, see the paper for details. According to the Preset Arena, 8 is a good value.
* **mirostat_eta**: No idea, see the paper for details. According to the Preset Arena, 0.1 is a good value.
* **dynamic_temperature**: Activates Dynamic Temperature. This modifies temperature to range between "dynatemp_low" (minimum) and "dynatemp_high" (maximum), with an entropy-based scaling. The steepness of the curve is controlled by "dynatemp_exponent".
* **smoothing_factor**: Activates Quadratic Sampling. This takes precedence over regular temperature and dynamic temperature, and replaces those samplers. When `0 < smoothing_factor < 1`, the logits distribution becomes flatter. When `smoothing_factor > 1`, it becomes more peaked. See the numerical sketch after this list.
* **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency.
* **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked).
* **Seed**: Set the Pytorch seed to this number. Note that some loaders do not use Pytorch (notably llama.cpp), and others are not deterministic (notably ExLlama v1 and v2). For these loaders, the seed has no effect.
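
For intuition about the new `smoothing_factor` option, here is a minimal sketch, not part of the commit itself, that applies the same quadratic transform to a toy logits tensor and compares the resulting distributions. The tensor values and the chosen factors are illustrative only.

```python
import torch

def quadratic_transform(logits: torch.Tensor, smoothing_factor: float) -> torch.Tensor:
    # Same formula as the new sampler: a downward parabola centered on the top logit.
    max_logit = logits.max()
    return -(smoothing_factor * (logits - max_logit) ** 2) + max_logit

logits = torch.tensor([4.0, 3.0, 2.0, 1.0])  # toy logits, not from a real model

print("plain softmax:", torch.softmax(logits, dim=-1).tolist())
for k in (0.3, 1.0, 3.0):
    probs = torch.softmax(quadratic_transform(logits, k), dim=-1)
    print(f"smoothing_factor={k}:", probs.tolist())
```

With `smoothing_factor = 0.3` the probabilities move toward uniform, while `smoothing_factor = 3.0` concentrates nearly all of the mass on the top token, matching the flatter/peaked behavior described above.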


@@ -12,6 +12,7 @@ class GenerationOptions(BaseModel):
dynatemp_low: float = 1
dynatemp_high: float = 1
dynatemp_exponent: float = 1
smoothing_factor: float = 0
top_k: int = 0
repetition_penalty: float = 1
repetition_penalty_range: int = 1024
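
Since `GenerationOptions` is a Pydantic request model, the new field can be supplied per request through the OpenAI-compatible API. A hedged sketch follows; the endpoint URL, port, and payload values are assumptions based on the extension's defaults rather than anything specified in this commit.

```python
import requests

# Assumed local endpoint of the OpenAI-compatible API extension;
# adjust host/port to wherever it is actually listening.
url = "http://127.0.0.1:5000/v1/completions"

payload = {
    "prompt": "Once upon a time",
    "max_tokens": 64,
    "smoothing_factor": 1.5,  # new field from this commit; 0 leaves it disabled
}

response = requests.post(url, json=payload, timeout=60)
print(response.json()["choices"][0]["text"])
```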


@@ -159,6 +159,7 @@ def transformers_samplers():
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'top_p',
'min_p',
'top_k',
@@ -233,6 +234,7 @@ loaders_samplers = {
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'top_p',
'min_p',
'top_k',
@@ -289,6 +291,7 @@ loaders_samplers = {
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'top_p',
'min_p',
'top_k',


@@ -17,6 +17,7 @@ def default_preset():
'dynatemp_low': 1,
'dynatemp_high': 1,
'dynatemp_exponent': 1,
'smoothing_factor': 0,
'top_p': 1,
'min_p': 0,
'top_k': 0,


@@ -15,8 +15,12 @@ from modules import shared
global_scores = None
class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
class ModifiedTemperatureLogitsWarper(LogitsWarper):
'''
Based on the original Transformers temperature logits warper, this
adds support for dynamic temperature and quadratic sampling.
'''
def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float, smoothing_factor: float):
if not isinstance(temperature, float) or not (temperature > 0):
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
@@ -32,16 +36,27 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
self.dynatemp_low = dynatemp_low
self.dynatemp_high = dynatemp_high
self.dynatemp_exponent = dynatemp_exponent
self.smoothing_factor = smoothing_factor
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Regular temperature
if not self.dynamic_temperature:
scores = scores / self.temperature
return scores
# Quadratic sampling
if self.smoothing_factor > 0:
# Compute the maximum logit value
max_logit = scores.max()
# Apply the quadratic transformation
transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
# Debug output from development, left commented out:
# print("Original top 5 logits: ", torch.topk(scores, 5))
# print("New top 5 logits: ", torch.topk(transformed_logits, 5))
return transformed_logits
# Dynamic temperature
else:
elif self.dynamic_temperature:
min_temp = self.dynatemp_low
max_temp = self.dynatemp_high
exponent_val = self.dynatemp_exponent
@@ -88,6 +103,11 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
return scores
# Regular temperature
else:
scores = scores / self.temperature
return scores
class MinPLogitsWarper(LogitsWarper):
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -286,7 +306,7 @@ def get_logits_warper_patch(self, generation_config):
generation_config.temperature = float(generation_config.temperature)
temperature = generation_config.temperature
if generation_config.dynamic_temperature:
if generation_config.dynamic_temperature or generation_config.smoothing_factor > 0:
# Make sure TemperatureLogitsWarper will be created by temporarily
# setting temperature to a value != 1.
generation_config.temperature = 1.1
@@ -294,12 +314,13 @@ def get_logits_warper_patch(self, generation_config):
warpers = self._get_logits_warper_old(generation_config)
for i in range(len(warpers)):
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
warpers[i] = TemperatureLogitsWarperWithDynatemp(
warpers[i] = ModifiedTemperatureLogitsWarper(
temperature,
generation_config.dynamic_temperature,
generation_config.dynatemp_low,
generation_config.dynatemp_high,
generation_config.dynatemp_exponent
generation_config.dynatemp_exponent,
generation_config.smoothing_factor
)
warpers_to_add = LogitsProcessorList()
@@ -328,7 +349,7 @@ def get_logits_warper_patch(self, generation_config):
if generation_config.temperature_last:
temperature_idx = None
for i in range(len(warpers)):
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'ModifiedTemperatureLogitsWarper']:
temperature_idx = i
break
@@ -352,8 +373,7 @@ def get_logits_processor_patch(self, **kwargs):
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
if do_rep_pen_hijack:
# Make sure that a RepetitionPenaltyLogitsProcessor will be created
kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1
kwargs['generation_config'].repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
result = self._get_logits_processor_old(**kwargs)
@@ -372,6 +392,7 @@ def generation_config_init_patch(self, **kwargs):
self.dynatemp_low = kwargs.pop("dynatemp_low", 1)
self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0)
self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0)
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
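
Putting the hunks above together: the patch temporarily sets `temperature` to 1.1 so that transformers still builds a temperature warper, then swaps that warper for `ModifiedTemperatureLogitsWarper`, which applies quadratic sampling first, dynamic temperature second, and plain temperature last. Below is a hedged sketch of the warper used in isolation; the import path and all numeric values are assumptions, and it presumes the webui's `modules` package is importable.

```python
import torch

# Import path is an assumption about where this commit places the class.
from modules.sampler_hijack import ModifiedTemperatureLogitsWarper

# Dummy batch: one sequence, five-token vocabulary (values made up).
scores = torch.tensor([[4.0, 3.0, 2.0, 1.0, 0.0]])
input_ids = torch.zeros((1, 1), dtype=torch.long)

warper = ModifiedTemperatureLogitsWarper(
    temperature=0.7,            # ignored while smoothing_factor > 0
    dynamic_temperature=False,
    dynatemp_low=1.0,
    dynatemp_high=1.0,
    dynatemp_exponent=1.0,
    smoothing_factor=1.5,       # > 0 switches on quadratic sampling
)

print(warper(input_ids, scores))  # quadratic transform, not scores / 0.7
```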


@@ -285,8 +285,9 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
generate_params[k] = state[k]
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
if k in state:
generate_params[k] = state[k]
if state['negative_prompt'] != '':
generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
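
One detail worth noting in the hunk above: each key is now guarded with `if k in state`, so a state dict that lacks `smoothing_factor` (for example an older saved preset, or a request that omits it) is simply skipped for that key instead of raising `KeyError`. A tiny self-contained illustration with hypothetical values:

```python
# Hypothetical state saved before smoothing_factor existed.
state = {'max_new_tokens': 200, 'temperature': 0.7}

generate_params = {}
for k in ['max_new_tokens', 'temperature', 'smoothing_factor']:
    if k in state:
        generate_params[k] = state[k]

print(generate_params)  # {'max_new_tokens': 200, 'temperature': 0.7}
```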


@@ -120,6 +120,7 @@ def list_interface_input_elements():
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'top_p',
'min_p',
'top_k',


@@ -49,6 +49,7 @@ def create_ui(default_preset):
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Replaces temperature with Quadratic Sampling.')
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_low'], step=0.01, label='dynatemp_low', visible=generate_params['dynamic_temperature'])
shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_high'], step=0.01, label='dynatemp_high', visible=generate_params['dynamic_temperature'])