From b6077b02e45eb904aa6f9eca142d38b9ee6a46e8 Mon Sep 17 00:00:00 2001
From: kalomaze <66376113+kalomaze@users.noreply.github.com>
Date: Sat, 3 Feb 2024 21:20:02 -0600
Subject: [PATCH] Quadratic sampling (#5403)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
---
 docs/03 - Parameters Tab.md |  1 +
 extensions/openai/typing.py |  1 +
 modules/loaders.py          |  3 +++
 modules/presets.py          |  1 +
 modules/sampler_hijack.py   | 47 +++++++++++++++++++++++++++----------
 modules/text_generation.py  |  5 ++--
 modules/ui.py               |  1 +
 modules/ui_parameters.py    |  1 +
 8 files changed, 45 insertions(+), 15 deletions(-)

diff --git a/docs/03 - Parameters Tab.md b/docs/03 - Parameters Tab.md
index 501cf510..affa9e73 100644
--- a/docs/03 - Parameters Tab.md
+++ b/docs/03 - Parameters Tab.md
@@ -55,6 +55,7 @@ For more information about the parameters, the [transformers documentation](http
 * **mirostat_tau**: No idea, see the paper for details. According to the Preset Arena, 8 is a good value.
 * **mirostat_eta**: No idea, see the paper for details. According to the Preset Arena, 0.1 is a good value.
 * **dynamic_temperature**: Activates Dynamic Temperature. This modifies temperature to range between "dynatemp_low" (minimum) and "dynatemp_high" (maximum), with an entropy-based scaling. The steepness of the curve is controlled by "dynatemp_exponent".
+* **smoothing_factor**: Activates Quadratic Sampling. This takes precedence over regular temperature and dynamic temperature, and replaces those samplers. When `0 < smoothing_factor < 1`, the logits distribution becomes flatter. When `smoothing_factor > 1`, it becomes more peaked.
 * **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency.
 * **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked).
 * **Seed**: Set the Pytorch seed to this number. Note that some loaders do not use Pytorch (notably llama.cpp), and others are not deterministic (notably ExLlama v1 and v2). For these loaders, the seed has no effect.
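The bullet added above describes the transform that this patch implements further down in `modules/sampler_hijack.py`. A minimal standalone sketch of that transform (illustrative only, not part of the patch; the helper name `quadratic_transform` is made up here, and `torch` is assumed to be installed):

```python
import torch

def quadratic_transform(logits: torch.Tensor, smoothing_factor: float) -> torch.Tensor:
    # Same idea as the patched warper: center a downward parabola on the top
    # logit, so the best token keeps its score and the others are pushed down
    # in proportion to their squared distance from the maximum.
    max_logit = logits.max()
    return -(smoothing_factor * (logits - max_logit) ** 2) + max_logit

logits = torch.tensor([4.0, 3.0, 2.0, 0.5])
print(torch.softmax(logits, dim=-1))                            # baseline distribution
print(torch.softmax(quadratic_transform(logits, 0.2), dim=-1))  # flatter than baseline
print(torch.softmax(quadratic_transform(logits, 3.0), dim=-1))  # sharper than baseline
```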
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 9c4a04f0..3deb464f 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -12,6 +12,7 @@ class GenerationOptions(BaseModel):
     dynatemp_low: float = 1
     dynatemp_high: float = 1
     dynatemp_exponent: float = 1
+    smoothing_factor: float = 0
     top_k: int = 0
     repetition_penalty: float = 1
     repetition_penalty_range: int = 1024
diff --git a/modules/loaders.py b/modules/loaders.py
index 2976a851..618d4502 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -159,6 +159,7 @@ def transformers_samplers():
         'dynatemp_low',
         'dynatemp_high',
         'dynatemp_exponent',
+        'smoothing_factor',
         'top_p',
         'min_p',
         'top_k',
@@ -233,6 +234,7 @@ loaders_samplers = {
         'dynatemp_low',
         'dynatemp_high',
         'dynatemp_exponent',
+        'smoothing_factor',
         'top_p',
         'min_p',
         'top_k',
@@ -289,6 +291,7 @@ loaders_samplers = {
         'dynatemp_low',
         'dynatemp_high',
         'dynatemp_exponent',
+        'smoothing_factor',
         'top_p',
         'min_p',
         'top_k',
diff --git a/modules/presets.py b/modules/presets.py
index 5e686e34..966c706e 100644
--- a/modules/presets.py
+++ b/modules/presets.py
@@ -17,6 +17,7 @@ def default_preset():
         'dynatemp_low': 1,
         'dynatemp_high': 1,
         'dynatemp_exponent': 1,
+        'smoothing_factor': 0,
         'top_p': 1,
         'min_p': 0,
         'top_k': 0,
diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
index e9d82d3c..59b90b02 100644
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -15,8 +15,12 @@ from modules import shared
 
 global_scores = None
 
 
-class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
-    def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
+class ModifiedTemperatureLogitsWarper(LogitsWarper):
+    '''
+    Based on the original Transformers temperature logits warper, this
+    adds support for dynamic temperature and quadratic sampling.
+    '''
+    def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float, smoothing_factor: float):
         if not isinstance(temperature, float) or not (temperature > 0):
             except_msg = (
                 f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
@@ -32,16 +36,27 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
         self.dynatemp_low = dynatemp_low
         self.dynatemp_high = dynatemp_high
         self.dynatemp_exponent = dynatemp_exponent
+        self.smoothing_factor = smoothing_factor
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        # Regular temperature
-        if not self.dynamic_temperature:
-            scores = scores / self.temperature
-            return scores
+        # Quadratic sampling
+        if self.smoothing_factor > 0:
+
+            # Compute the maximum logit value
+            max_logit = scores.max()
+
+            # Apply the quadratic transformation
+            transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
+
+            # No need to print the top 5 logits since this is not required
+            # print("Original top 5 logits: ", torch.topk(scores, 5))
+            # print("New top 5 logits: ", torch.topk(transformed_logits, 5))
+
+            return transformed_logits
 
         # Dynamic temperature
-        else:
+        elif self.dynamic_temperature:
             min_temp = self.dynatemp_low
             max_temp = self.dynatemp_high
             exponent_val = self.dynatemp_exponent
@@ -88,6 +103,11 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
 
             return scores
 
+        # Regular temperature
+        else:
+            scores = scores / self.temperature
+            return scores
+
 
 class MinPLogitsWarper(LogitsWarper):
     def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -286,7 +306,7 @@ def get_logits_warper_patch(self, generation_config):
         generation_config.temperature = float(generation_config.temperature)
 
     temperature = generation_config.temperature
-    if generation_config.dynamic_temperature:
+    if generation_config.dynamic_temperature or generation_config.smoothing_factor > 0:
         # Make sure TemperatureLogitsWarper will be created by temporarily
         # setting temperature to a value != 1.
         generation_config.temperature = 1.1
@@ -294,12 +314,13 @@ def get_logits_warper_patch(self, generation_config):
     warpers = self._get_logits_warper_old(generation_config)
     for i in range(len(warpers)):
         if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
-            warpers[i] = TemperatureLogitsWarperWithDynatemp(
+            warpers[i] = ModifiedTemperatureLogitsWarper(
                 temperature,
                 generation_config.dynamic_temperature,
                 generation_config.dynatemp_low,
                 generation_config.dynatemp_high,
-                generation_config.dynatemp_exponent
+                generation_config.dynatemp_exponent,
+                generation_config.smoothing_factor
             )
 
     warpers_to_add = LogitsProcessorList()
@@ -328,7 +349,7 @@ def get_logits_warper_patch(self, generation_config):
     if generation_config.temperature_last:
         temperature_idx = None
         for i in range(len(warpers)):
-            if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
+            if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'ModifiedTemperatureLogitsWarper']:
                 temperature_idx = i
                 break
 
@@ -352,8 +373,7 @@ def get_logits_processor_patch(self, **kwargs):
     repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
     do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
     if do_rep_pen_hijack:
-        # Make sure that a RepetitionPenaltyLogitsProcessor will be created
-        kwargs['generation_config'].repetition_penalty = 1.1  # must set to some value > 1
+        kwargs['generation_config'].repetition_penalty = 1.1  # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
 
     result = self._get_logits_processor_old(**kwargs)
 
@@ -372,6 +392,7 @@ def generation_config_init_patch(self, **kwargs):
     self.dynatemp_low = kwargs.pop("dynatemp_low", 1)
     self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
     self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
+    self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0)
     self.tfs = kwargs.pop("tfs", 1.0)
     self.top_a = kwargs.pop("top_a", 0.0)
     self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
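The hijack above does not build a new sampler chain from scratch: it lets transformers assemble its usual warper list and then swaps the stock `TemperatureLogitsWarper` entry for `ModifiedTemperatureLogitsWarper`, which handles plain temperature, dynamic temperature, and quadratic sampling in one class. For readers unfamiliar with that machinery, here is a minimal sketch of how such a warper chain is applied to next-token logits, using only stock transformers classes rather than the patched one (dummy tensor shapes assumed for illustration):

```python
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper

# A chain like the one get_logits_warper_patch() edits in place; each warper
# receives (input_ids, scores) and returns adjusted scores for the next token.
warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),
    TopKLogitsWarper(top_k=40),
])

input_ids = torch.tensor([[1, 2, 3]])  # dummy prompt token ids, shape (batch, seq)
scores = torch.randn(1, 32000)         # dummy next-token logits, shape (batch, vocab)
warped = warpers(input_ids, scores)
print(warped.shape)
```

In the patched code, the `TemperatureLogitsWarper` slot in this list would instead hold a `ModifiedTemperatureLogitsWarper` built from the temperature, dynatemp, and smoothing_factor values on the generation config.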
diff --git a/modules/text_generation.py b/modules/text_generation.py
index f4849840..2796bfe1 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -285,8 +285,9 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
 
 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
-        generate_params[k] = state[k]
+    for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
+        if k in state:
+            generate_params[k] = state[k]
 
     if state['negative_prompt'] != '':
         generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
diff --git a/modules/ui.py b/modules/ui.py
index b639c4df..53a8fd14 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -120,6 +120,7 @@ def list_interface_input_elements():
         'dynatemp_low',
         'dynatemp_high',
         'dynatemp_exponent',
+        'smoothing_factor',
         'top_p',
         'min_p',
         'top_k',
diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py
index 63a3743a..a81ed27a 100644
--- a/modules/ui_parameters.py
+++ b/modules/ui_parameters.py
@@ -49,6 +49,7 @@ def create_ui(default_preset):
                     shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
                     shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
                     shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
+                    shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Replaces temperature with Quadratic Sampling.')
                     shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
                     shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_low'], step=0.01, label='dynatemp_low', visible=generate_params['dynamic_temperature'])
                     shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_high'], step=0.01, label='dynatemp_high', visible=generate_params['dynamic_temperature'])
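Because `smoothing_factor` is now a field of `GenerationOptions` in extensions/openai/typing.py and a regular generation parameter, it can also be set per request through the OpenAI-compatible API extension. A sketch of such a request (assumes the API extension is enabled and reachable at the default local address; the URL, port, and prompt here are placeholders, adjust them for your setup):

```python
import requests

# Hypothetical request against a local text-generation-webui instance with the
# openai extension enabled; host and port depend on how the server was launched.
url = "http://127.0.0.1:5000/v1/completions"
payload = {
    "prompt": "Once upon a time",
    "max_tokens": 64,
    "temperature": 1.0,        # ignored while smoothing_factor > 0
    "smoothing_factor": 0.25,  # quadratic sampling instead of temperature
}
response = requests.post(url, json=payload)
print(response.json()["choices"][0]["text"])
```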