mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Add additive_repetition_penalty sampler setting. (#3627)
This commit is contained in:
parent
6086768309
commit
4440f87722
@ -52,6 +52,7 @@ async def run(user_input, history):
|
|||||||
'tfs': 1,
|
'tfs': 1,
|
||||||
'top_a': 0,
|
'top_a': 0,
|
||||||
'repetition_penalty': 1.18,
|
'repetition_penalty': 1.18,
|
||||||
|
'additive_repetition_penalty': 0,
|
||||||
'repetition_penalty_range': 0,
|
'repetition_penalty_range': 0,
|
||||||
'top_k': 40,
|
'top_k': 40,
|
||||||
'min_length': 0,
|
'min_length': 0,
|
||||||
|
@ -46,6 +46,7 @@ def run(user_input, history):
|
|||||||
'tfs': 1,
|
'tfs': 1,
|
||||||
'top_a': 0,
|
'top_a': 0,
|
||||||
'repetition_penalty': 1.18,
|
'repetition_penalty': 1.18,
|
||||||
|
'additive_repetition_penalty': 0,
|
||||||
'repetition_penalty_range': 0,
|
'repetition_penalty_range': 0,
|
||||||
'top_k': 40,
|
'top_k': 40,
|
||||||
'min_length': 0,
|
'min_length': 0,
|
||||||
|
@ -35,6 +35,7 @@ async def run(context):
|
|||||||
'tfs': 1,
|
'tfs': 1,
|
||||||
'top_a': 0,
|
'top_a': 0,
|
||||||
'repetition_penalty': 1.18,
|
'repetition_penalty': 1.18,
|
||||||
|
'additive_repetition_penalty': 0,
|
||||||
'repetition_penalty_range': 0,
|
'repetition_penalty_range': 0,
|
||||||
'top_k': 40,
|
'top_k': 40,
|
||||||
'min_length': 0,
|
'min_length': 0,
|
||||||
|
@ -27,6 +27,7 @@ def run(prompt):
|
|||||||
'tfs': 1,
|
'tfs': 1,
|
||||||
'top_a': 0,
|
'top_a': 0,
|
||||||
'repetition_penalty': 1.18,
|
'repetition_penalty': 1.18,
|
||||||
|
'additive_repetition_penalty': 0,
|
||||||
'repetition_penalty_range': 0,
|
'repetition_penalty_range': 0,
|
||||||
'top_k': 40,
|
'top_k': 40,
|
||||||
'min_length': 0,
|
'min_length': 0,
|
||||||
|
@ -35,6 +35,7 @@ For more information about the parameters, the [transformers documentation](http
|
|||||||
* **top_p**: If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.
|
* **top_p**: If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.
|
||||||
* **top_k**: Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.
|
* **top_k**: Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.
|
||||||
* **repetition_penalty**: Penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.
|
* **repetition_penalty**: Penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.
|
||||||
|
* **additive_repetition_penalty**: Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. It may generate better results. 0 means no penalty, higher value = less repetition, lower value = more repetition.
|
||||||
* **repetition_penalty_range**: The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
|
* **repetition_penalty_range**: The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
|
||||||
* **typical_p**: If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
|
* **typical_p**: If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
|
||||||
* **tfs**: Tries to detect a tail of low-probability tokens in the distribution and removes those tokens. See [this blog post](https://www.trentonbricken.com/Tail-Free-Sampling/) for details. The closer to 0, the more discarded tokens.
|
* **tfs**: Tries to detect a tail of low-probability tokens in the distribution and removes those tokens. See [this blog post](https://www.trentonbricken.com/Tail-Free-Sampling/) for details. The closer to 0, the more discarded tokens.
|
||||||
|
@ -32,6 +32,7 @@ def build_parameters(body, chat=False):
|
|||||||
'tfs': float(body.get('tfs', 1)),
|
'tfs': float(body.get('tfs', 1)),
|
||||||
'top_a': float(body.get('top_a', 0)),
|
'top_a': float(body.get('top_a', 0)),
|
||||||
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
|
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
|
||||||
|
'additive_repetition_penalty': float(body.get('additive_repetition_penalty', body.get('additive_rep_pen', 0))),
|
||||||
'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
|
'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
|
||||||
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
|
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
|
||||||
'top_k': int(body.get('top_k', 0)),
|
'top_k': int(body.get('top_k', 0)),
|
||||||
|
@ -10,6 +10,7 @@ default_req_params = {
|
|||||||
'top_p': 1.0,
|
'top_p': 1.0,
|
||||||
'top_k': 1, # choose 20 for chat in absence of another default
|
'top_k': 1, # choose 20 for chat in absence of another default
|
||||||
'repetition_penalty': 1.18,
|
'repetition_penalty': 1.18,
|
||||||
|
'additive_repetition_penalty': 0,
|
||||||
'repetition_penalty_range': 0,
|
'repetition_penalty_range': 0,
|
||||||
'encoder_repetition_penalty': 1.0,
|
'encoder_repetition_penalty': 1.0,
|
||||||
'suffix': None,
|
'suffix': None,
|
||||||
|
@ -153,6 +153,7 @@ loaders_samplers = {
|
|||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'repetition_penalty',
|
'repetition_penalty',
|
||||||
|
'additive_repetition_penalty',
|
||||||
'repetition_penalty_range',
|
'repetition_penalty_range',
|
||||||
'encoder_repetition_penalty',
|
'encoder_repetition_penalty',
|
||||||
'no_repeat_ngram_size',
|
'no_repeat_ngram_size',
|
||||||
@ -186,6 +187,7 @@ loaders_samplers = {
|
|||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'repetition_penalty',
|
'repetition_penalty',
|
||||||
|
'additive_repetition_penalty',
|
||||||
'repetition_penalty_range',
|
'repetition_penalty_range',
|
||||||
'encoder_repetition_penalty',
|
'encoder_repetition_penalty',
|
||||||
'no_repeat_ngram_size',
|
'no_repeat_ngram_size',
|
||||||
@ -244,6 +246,7 @@ loaders_samplers = {
|
|||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'repetition_penalty',
|
'repetition_penalty',
|
||||||
|
'additive_repetition_penalty',
|
||||||
'repetition_penalty_range',
|
'repetition_penalty_range',
|
||||||
'encoder_repetition_penalty',
|
'encoder_repetition_penalty',
|
||||||
'no_repeat_ngram_size',
|
'no_repeat_ngram_size',
|
||||||
@ -273,6 +276,7 @@ loaders_samplers = {
|
|||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'repetition_penalty',
|
'repetition_penalty',
|
||||||
|
'additive_repetition_penalty',
|
||||||
'repetition_penalty_range',
|
'repetition_penalty_range',
|
||||||
'encoder_repetition_penalty',
|
'encoder_repetition_penalty',
|
||||||
'no_repeat_ngram_size',
|
'no_repeat_ngram_size',
|
||||||
@ -306,6 +310,7 @@ loaders_samplers = {
|
|||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'repetition_penalty',
|
'repetition_penalty',
|
||||||
|
'additive_repetition_penalty',
|
||||||
'repetition_penalty_range',
|
'repetition_penalty_range',
|
||||||
'encoder_repetition_penalty',
|
'encoder_repetition_penalty',
|
||||||
'no_repeat_ngram_size',
|
'no_repeat_ngram_size',
|
||||||
@ -353,6 +358,7 @@ loaders_samplers = {
|
|||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'repetition_penalty',
|
'repetition_penalty',
|
||||||
|
'additive_repetition_penalty',
|
||||||
'repetition_penalty_range',
|
'repetition_penalty_range',
|
||||||
'encoder_repetition_penalty',
|
'encoder_repetition_penalty',
|
||||||
'no_repeat_ngram_size',
|
'no_repeat_ngram_size',
|
||||||
@ -389,6 +395,7 @@ loaders_samplers = {
|
|||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'repetition_penalty',
|
'repetition_penalty',
|
||||||
|
'additive_repetition_penalty',
|
||||||
'repetition_penalty_range',
|
'repetition_penalty_range',
|
||||||
'encoder_repetition_penalty',
|
'encoder_repetition_penalty',
|
||||||
'no_repeat_ngram_size',
|
'no_repeat_ngram_size',
|
||||||
|
@ -16,6 +16,7 @@ def default_preset():
|
|||||||
'tfs': 1,
|
'tfs': 1,
|
||||||
'top_a': 0,
|
'top_a': 0,
|
||||||
'repetition_penalty': 1,
|
'repetition_penalty': 1,
|
||||||
|
'additive_repetition_penalty': 0,
|
||||||
'repetition_penalty_range': 0,
|
'repetition_penalty_range': 0,
|
||||||
'encoder_repetition_penalty': 1,
|
'encoder_repetition_penalty': 1,
|
||||||
'no_repeat_ngram_size': 0,
|
'no_repeat_ngram_size': 0,
|
||||||
|
@ -139,11 +139,12 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
|||||||
Copied from the transformers library
|
Copied from the transformers library
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, penalty: float, _range: int):
|
def __init__(self, penalty: float, additive_penalty: float, _range: int):
|
||||||
if not isinstance(penalty, float) or not (penalty > 0):
|
if not (penalty > 0):
|
||||||
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")
|
||||||
|
|
||||||
self.penalty = penalty
|
self.penalty = penalty
|
||||||
|
self.additive_penalty = additive_penalty
|
||||||
self._range = _range
|
self._range = _range
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
@ -153,6 +154,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
|||||||
|
|
||||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||||
|
score -= self.additive_penalty
|
||||||
|
|
||||||
scores.scatter_(1, input_ids, score)
|
scores.scatter_(1, input_ids, score)
|
||||||
return scores
|
return scores
|
||||||
@ -185,14 +187,20 @@ def get_logits_warper_patch(self, generation_config):
|
|||||||
|
|
||||||
|
|
||||||
def get_logits_processor_patch(self, **kwargs):
|
def get_logits_processor_patch(self, **kwargs):
|
||||||
result = self._get_logits_processor_old(**kwargs)
|
|
||||||
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
|
|
||||||
repetition_penalty = kwargs['generation_config'].repetition_penalty
|
repetition_penalty = kwargs['generation_config'].repetition_penalty
|
||||||
|
additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty
|
||||||
|
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
|
||||||
|
do_rep_pen_hijack = (repetition_penalty > 1) or (additive_repetition_penalty > 0)
|
||||||
|
if do_rep_pen_hijack:
|
||||||
|
# Make sure that a RepetitionPenaltyLogitsProcessor will be created
|
||||||
|
kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1
|
||||||
|
|
||||||
if repetition_penalty_range > 0:
|
result = self._get_logits_processor_old(**kwargs)
|
||||||
|
|
||||||
|
if do_rep_pen_hijack:
|
||||||
for i in range(len(result)):
|
for i in range(len(result)):
|
||||||
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
|
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
|
||||||
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range)
|
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, additive_repetition_penalty, repetition_penalty_range)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -205,6 +213,7 @@ def generation_config_init_patch(self, **kwargs):
|
|||||||
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
|
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
|
||||||
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
|
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
|
||||||
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
|
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
|
||||||
|
self.additive_repetition_penalty = kwargs.pop("additive_repetition_penalty", 0)
|
||||||
|
|
||||||
|
|
||||||
def hijack_samplers():
|
def hijack_samplers():
|
||||||
|
@ -273,7 +273,7 @@ def apply_stopping_strings(reply, all_stop_strings):
|
|||||||
|
|
||||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
||||||
generate_params[k] = state[k]
|
generate_params[k] = state[k]
|
||||||
|
|
||||||
if state['negative_prompt'] != '':
|
if state['negative_prompt'] != '':
|
||||||
|
@ -105,6 +105,7 @@ def list_interface_input_elements():
|
|||||||
'epsilon_cutoff',
|
'epsilon_cutoff',
|
||||||
'eta_cutoff',
|
'eta_cutoff',
|
||||||
'repetition_penalty',
|
'repetition_penalty',
|
||||||
|
'additive_repetition_penalty',
|
||||||
'repetition_penalty_range',
|
'repetition_penalty_range',
|
||||||
'encoder_repetition_penalty',
|
'encoder_repetition_penalty',
|
||||||
'no_repeat_ngram_size',
|
'no_repeat_ngram_size',
|
||||||
|
@ -31,6 +31,7 @@ def create_ui(default_preset):
|
|||||||
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
|
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
|
||||||
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
|
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
|
||||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
|
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
|
||||||
|
shared.gradio['additive_repetition_penalty'] = gr.Slider(0, 4, value=generate_params['additive_repetition_penalty'], step=0.05, label='additive_repetition_penalty')
|
||||||
shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
|
shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
|
||||||
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
|
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
|
||||||
shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs')
|
shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs')
|
||||||
|
Loading…
Reference in New Issue
Block a user