mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-23 01:59:21 +01:00
Add repetition penalty range parameter to transformers (#2916)
This commit is contained in:
parent
c6cae106e7
commit
3443219cbc
@ -44,6 +44,7 @@ async def run(user_input, history):
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
'repetition_penalty_range': 0,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
|
@ -38,6 +38,7 @@ def run(user_input, history):
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
'repetition_penalty_range': 0,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
|
@ -33,6 +33,7 @@ async def run(context):
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
'repetition_penalty_range': 0,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
|
@ -25,6 +25,7 @@ def run(prompt):
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
'repetition_penalty_range': 0,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
|
@ -21,6 +21,7 @@ def build_parameters(body, chat=False):
|
||||
'tfs': float(body.get('tfs', 1)),
|
||||
'top_a': float(body.get('top_a', 0)),
|
||||
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
|
||||
'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
|
||||
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'min_length': int(body.get('min_length', 0)),
|
||||
|
@ -29,6 +29,7 @@ default_req_params = {
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
'repetition_penalty': 1.18,
|
||||
'repetition_penalty_range': 0,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'suffix': None,
|
||||
'stream': False,
|
||||
|
@ -71,6 +71,7 @@ class ExllamaModel:
|
||||
self.generator.settings.top_k = state['top_k']
|
||||
self.generator.settings.typical = state['typical_p']
|
||||
self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
|
||||
self.generator.settings.token_repetition_penalty_sustain = state['repetition_penalty_range']
|
||||
if state['ban_eos_token']:
|
||||
self.generator.disallow_tokens([self.tokenizer.eos_token_id])
|
||||
else:
|
||||
|
@ -15,6 +15,7 @@ def load_preset(name):
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1,
|
||||
'repetition_penalty_range': 0,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': 0,
|
||||
'num_beams': 1,
|
||||
@ -46,9 +47,9 @@ def load_preset_memoized(name):
|
||||
def load_preset_for_ui(name, state):
|
||||
generate_params = load_preset(name)
|
||||
state.update(generate_params)
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
|
||||
|
||||
|
||||
def generate_preset_yaml(state):
|
||||
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
|
||||
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
|
||||
return yaml.dump(data, sort_keys=False)
|
||||
|
@ -5,6 +5,7 @@ import transformers
|
||||
from transformers import LogitsWarper
|
||||
from transformers.generation.logits_process import (
|
||||
LogitNormalization,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
TemperatureLogitsWarper
|
||||
)
|
||||
@ -121,6 +122,29 @@ class MirostatLogitsWarper(LogitsWarper):
|
||||
return scores
|
||||
|
||||
|
||||
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||
'''
|
||||
Copied from the transformers library
|
||||
'''
|
||||
def __init__(self, penalty: float, _range: int):
|
||||
if not isinstance(penalty, float) or not (penalty > 0):
|
||||
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
||||
|
||||
self.penalty = penalty
|
||||
self._range = _range
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
input_ids = input_ids[:, -self._range:]
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||
|
||||
scores.scatter_(1, input_ids, score)
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_warper_patch(self, generation_config):
|
||||
warpers = self._get_logits_warper_old(generation_config)
|
||||
warpers_to_add = LogitsProcessorList()
|
||||
@ -146,6 +170,19 @@ def get_logits_warper_patch(self, generation_config):
|
||||
return warpers
|
||||
|
||||
|
||||
def get_logits_processor_patch(self, **kwargs):
|
||||
result = self._get_logits_processor_old(**kwargs)
|
||||
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
|
||||
repetition_penalty = kwargs['generation_config'].repetition_penalty
|
||||
|
||||
if repetition_penalty_range > 0:
|
||||
for i in range(len(result)):
|
||||
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
|
||||
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def generation_config_init_patch(self, **kwargs):
|
||||
self.__init___old(**kwargs)
|
||||
self.tfs = kwargs.pop("tfs", 1.0)
|
||||
@ -153,11 +190,15 @@ def generation_config_init_patch(self, **kwargs):
|
||||
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
||||
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
|
||||
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
|
||||
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
|
||||
|
||||
|
||||
def hijack_samplers():
|
||||
transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
|
||||
transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch
|
||||
|
||||
transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
|
||||
transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch
|
||||
|
||||
transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__
|
||||
transformers.GenerationConfig.__init__ = generation_config_init_patch
|
||||
|
@ -230,7 +230,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
||||
|
||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
for k in ['epsilon_cutoff', 'eta_cutoff']:
|
||||
|
@ -38,7 +38,7 @@ def list_model_elements():
|
||||
|
||||
|
||||
def list_interface_input_elements(chat=False):
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream', 'tfs', 'top_a']
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream', 'tfs', 'top_a']
|
||||
if chat:
|
||||
elements += ['name1', 'name2', 'greeting', 'context', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command']
|
||||
|
||||
|
@ -336,6 +336,7 @@ def create_settings_menus(default_preset):
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.')
|
||||
shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range', info='The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.')
|
||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.')
|
||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.')
|
||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.')
|
||||
@ -376,7 +377,7 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
|
||||
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
|
||||
|
||||
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
|
||||
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
|
||||
|
||||
|
||||
def create_file_saving_menus():
|
||||
|
Loading…
Reference in New Issue
Block a user