mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 15:48:04 +01:00
Cubic sampling w/ curve param (#5551)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
3168644152
commit
cfb25c9b3f
@ -13,6 +13,7 @@ class GenerationOptions(BaseModel):
|
||||
dynatemp_high: float = 1
|
||||
dynatemp_exponent: float = 1
|
||||
smoothing_factor: float = 0
|
||||
smoothing_curve: float = 1
|
||||
top_k: int = 0
|
||||
repetition_penalty: float = 1
|
||||
repetition_penalty_range: int = 1024
|
||||
|
@ -164,6 +164,7 @@ def transformers_samplers():
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'smoothing_curve',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
@ -240,6 +241,7 @@ loaders_samplers = {
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'smoothing_curve',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
@ -298,6 +300,7 @@ loaders_samplers = {
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'smoothing_curve',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
|
@ -19,6 +19,7 @@ def default_preset():
|
||||
'dynatemp_high': 1,
|
||||
'dynatemp_exponent': 1,
|
||||
'smoothing_factor': 0,
|
||||
'smoothing_curve': 1,
|
||||
'top_p': 1,
|
||||
'min_p': 0,
|
||||
'top_k': 0,
|
||||
@ -109,7 +110,7 @@ def random_preset(state):
|
||||
[1, 2],
|
||||
[1, 5]
|
||||
],
|
||||
'smoothing_factor': [0.2, 0.3, 0.6, 1.2]
|
||||
'smoothing_factor': [0.2, 0.3, 0.6, 1.2],
|
||||
},
|
||||
'repetition': {
|
||||
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
|
||||
|
@ -99,22 +99,27 @@ class DynamicTemperatureLogitsWarper(LogitsWarper):
|
||||
|
||||
class QuadraticSamplingLogitsWarper(LogitsWarper):
|
||||
'''
|
||||
Quadratic sampling.
|
||||
Quadratic sampling with smoothing factor and smoothing curve parameters.
|
||||
'''
|
||||
|
||||
def __init__(self, smoothing_factor: float):
|
||||
def __init__(self, smoothing_factor, smoothing_curve):
|
||||
self.smoothing_factor = smoothing_factor
|
||||
self.smoothing_curve = smoothing_curve
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Compute the maximum logit value
|
||||
|
||||
# Compute necessary values
|
||||
max_logit = scores.max()
|
||||
diff = scores - max_logit
|
||||
k = (3 - self.smoothing_curve) / 2
|
||||
s = (self.smoothing_curve - 1) / 2
|
||||
|
||||
# Apply the quadratic transformation
|
||||
transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
|
||||
|
||||
# No need to print the top 5 logits since this is not required
|
||||
# print("Original top 5 logits: ", torch.topk(scores, 5))
|
||||
# print("New top 5 logits: ", torch.topk(transformed_logits, 5))
|
||||
# Apply transformation to non-negative infinity values
|
||||
transformed_logits = torch.where(
|
||||
scores != float('-inf'),
|
||||
-(k * self.smoothing_factor * diff**2) + (s * self.smoothing_factor * diff**3) + max_logit,
|
||||
scores
|
||||
)
|
||||
|
||||
return transformed_logits
|
||||
|
||||
@ -367,7 +372,8 @@ def get_logits_warper_patch(self, generation_config):
|
||||
if generation_config.smoothing_factor > 0:
|
||||
warpers_to_add.append(
|
||||
QuadraticSamplingLogitsWarper(
|
||||
smoothing_factor=generation_config.smoothing_factor
|
||||
smoothing_factor=generation_config.smoothing_factor,
|
||||
smoothing_curve=generation_config.smoothing_curve
|
||||
)
|
||||
)
|
||||
|
||||
@ -468,6 +474,7 @@ def generation_config_init_patch(self, **kwargs):
|
||||
self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
|
||||
self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
|
||||
self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0)
|
||||
self.smoothing_curve = kwargs.pop("smoothing_curve", 1.0)
|
||||
self.tfs = kwargs.pop("tfs", 1.0)
|
||||
self.top_a = kwargs.pop("top_a", 0.0)
|
||||
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
||||
|
@ -286,7 +286,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
|
||||
|
||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
|
||||
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
|
||||
if k in state:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
|
@ -123,6 +123,7 @@ def list_interface_input_elements():
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'smoothing_curve',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
|
@ -50,6 +50,7 @@ def create_ui(default_preset):
|
||||
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
||||
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
||||
shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Activates Quadratic Sampling.')
|
||||
shared.gradio['smoothing_curve'] = gr.Slider(1.0, 10.0, value=generate_params['smoothing_curve'], step=0.01, label='smoothing_curve', info='Adjusts the dropoff curve of Quadratic Sampling.')
|
||||
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
|
||||
shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_low'], step=0.01, label='dynatemp_low', visible=generate_params['dynamic_temperature'])
|
||||
shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_high'], step=0.01, label='dynatemp_high', visible=generate_params['dynamic_temperature'])
|
||||
|
Loading…
Reference in New Issue
Block a user