Mirror of https://github.com/oobabooga/text-generation-webui.git
Add customizable ban tokens (#3899)
commit f01b9aa71f (parent fb864dad7b)

Adds a new custom_token_bans parameter: a comma-separated string of token IDs that the model is never allowed to generate. The setting is threaded through the API example scripts, the API's build_parameters helper, the default request parameters, the ExLlama, ExLlamav2, llama.cpp and Transformers generation paths, the loaders_samplers tables, the default preset, shared settings, the Gradio UI, and the settings template.
@@ -70,6 +70,7 @@ async def run(user_input, history):
         'add_bos_token': True,
         'truncation_length': 2048,
         'ban_eos_token': False,
+        'custom_token_bans': '',
         'skip_special_tokens': True,
         'stopping_strings': []
     }
@@ -64,6 +64,7 @@ def run(user_input, history):
         'add_bos_token': True,
         'truncation_length': 2048,
         'ban_eos_token': False,
+        'custom_token_bans': '',
         'skip_special_tokens': True,
         'stopping_strings': []
     }
@@ -53,6 +53,7 @@ async def run(context):
         'add_bos_token': True,
         'truncation_length': 2048,
         'ban_eos_token': False,
+        'custom_token_bans': '',
         'skip_special_tokens': True,
         'stopping_strings': []
     }
@@ -45,6 +45,7 @@ def run(prompt):
         'add_bos_token': True,
         'truncation_length': 2048,
         'ban_eos_token': False,
+        'custom_token_bans': '',
         'skip_special_tokens': True,
         'stopping_strings': []
     }
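All four API example scripts gain the same default, an empty custom_token_bans string. Below is a minimal client-side sketch of a request that actually uses it; the localhost:5000 endpoint and the results[0]['text'] response shape follow the repo's blocking-API examples, and the banned token IDs are made-up placeholders:

import requests

HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/generate'

request = {
    'prompt': 'Once upon a time',
    'max_new_tokens': 200,
    'ban_eos_token': False,
    'custom_token_bans': '13, 29',   # placeholder token IDs, comma-separated
    'skip_special_tokens': True,
    'stopping_strings': []
}

response = requests.post(URI, json=request)
if response.status_code == 200:
    print(response.json()['results'][0]['text'])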
@@ -49,6 +49,7 @@ def build_parameters(body, chat=False):
         'seed': int(body.get('seed', -1)),
         'add_bos_token': bool(body.get('add_bos_token', True)),
         'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
+        'custom_token_bans': str(body.get('custom_token_bans', '')),
         'ban_eos_token': bool(body.get('ban_eos_token', False)),
         'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
         'custom_stopping_strings': '', # leave this blank
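On the server side, build_parameters coerces whatever arrives in the request body; a missing key falls back to the empty string, which leaves the feature disabled. A tiny standalone sketch of that behaviour with placeholder values:

body = {'prompt': 'Hello', 'custom_token_bans': '13, 29'}   # hypothetical request body

custom_token_bans = str(body.get('custom_token_bans', ''))
print(custom_token_bans)                      # '13, 29'
print(str({}.get('custom_token_bans', '')))   # '' -> no bans applied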
@@ -37,6 +37,7 @@ default_req_params = {
     'guidance_scale': 1,
     'negative_prompt': '',
     'ban_eos_token': False,
+    'custom_token_bans': '',
     'skip_special_tokens': True,
     'custom_stopping_strings': '',
     # 'logits_processor' - conditionally passed
@@ -108,6 +108,11 @@ class ExllamaModel:
         else:
             self.generator.disallow_tokens(None)
 
+        if state['custom_token_bans']:
+            to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
+            if len(to_ban) > 0:
+                self.generator.disallow_tokens(self.tokenizer, to_ban)
+
         # Case 1: no CFG
         if state['guidance_scale'] == 1:
             self.generator.end_beam_search()
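The parsing idiom above reappears in every backend: split the comma-separated string and convert each piece to an int. A standalone sketch showing that it also tolerates spaces after the commas (the IDs are placeholders):

state = {'custom_token_bans': '13, 29,8571'}   # placeholder token IDs

if state['custom_token_bans']:
    to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
    print(to_ban)   # [13, 29, 8571]; int() strips the surrounding whitespace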
@@ -30,7 +30,7 @@ class Exllamav2Model:
         config.max_seq_len = shared.args.max_seq_len
         config.scale_pos_emb = shared.args.compress_pos_emb
         config.scale_alpha_value = shared.args.alpha_value
 
         model = ExLlamaV2(config)
 
         split = None
@@ -60,6 +60,11 @@ class Exllamav2Model:
         if state['ban_eos_token']:
             settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
 
+        if state['custom_token_bans']:
+            to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
+            if len(to_ban) > 0:
+                settings.disallow_tokens(self.tokenizer, to_ban)
+
         ids = self.tokenizer.encode(prompt)
         ids = ids[:, -get_max_prompt_length(state):]
         initial_len = ids.shape[-1]
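Here both toggles, ban_eos_token and custom_token_bans, funnel into the same settings.disallow_tokens call. A hypothetical helper (not part of the commit) that merges them into one list makes the relationship explicit:

def build_disallow_list(state, eos_token_id):
    # Hypothetical helper: collect every banned token ID implied by the state dict.
    banned = []
    if state['ban_eos_token']:
        banned.append(eos_token_id)
    if state['custom_token_bans']:
        banned += [int(x) for x in state['custom_token_bans'].split(',')]
    return banned

print(build_disallow_list({'ban_eos_token': True, 'custom_token_bans': '13'}, eos_token_id=2))   # [2, 13]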
@@ -31,6 +31,13 @@ def ban_eos_logits_processor(eos_token, input_ids, logits):
     return logits
 
 
+def custom_token_ban_logits_processor(token_ids, input_ids, logits):
+    for token_id in token_ids:
+        logits[token_id] = -float('inf')
+
+    return logits
+
+
 class LlamaCppModel:
     def __init__(self):
         self.initialized = False
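custom_token_ban_logits_processor works by pushing the banned logits to negative infinity, so the softmax that follows assigns those tokens exactly zero probability. A toy demonstration over a 4-token vocabulary (PyTorch is used only to show the effect):

import torch

logits = torch.tensor([1.0, 2.0, 3.0, 0.5])   # toy 4-token vocabulary
for token_id in [2]:                           # ban token 2
    logits[token_id] = -float('inf')

probs = torch.softmax(logits, dim=-1)
print(probs)            # token 2 now has probability 0.0
assert probs[2] == 0.0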
@@ -104,6 +111,15 @@ class LlamaCppModel:
         prompt = prompt[-get_max_prompt_length(state):]
         prompt = self.decode(prompt).decode('utf-8')
 
+        logit_processors = LogitsProcessorList()
+        if state['ban_eos_token']:
+            logit_processors.append(partial(ban_eos_logits_processor, self.model.tokenizer.eos_token_id))
+
+        if state['custom_token_bans']:
+            to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
+            if len(to_ban) > 0:
+                logit_processors.append(partial(custom_token_ban_logits_processor, to_ban))
+
         completion_chunks = self.model.create_completion(
             prompt=prompt,
             max_tokens=state['max_new_tokens'],
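functools.partial is what adapts the three-argument processor above to the two-argument (input_ids, logits) callable that the processor list invokes on every step. A self-contained sketch of the same wiring:

from functools import partial

def custom_token_ban_logits_processor(token_ids, input_ids, logits):
    for token_id in token_ids:
        logits[token_id] = -float('inf')
    return logits

processor = partial(custom_token_ban_logits_processor, [5, 7])   # pre-bind the ban list
logits = [0.0] * 10
processor([], logits)                                            # called as processor(input_ids, logits)
print(logits[5], logits[7])                                      # -inf -inf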
@@ -116,9 +132,7 @@ class LlamaCppModel:
             mirostat_tau=state['mirostat_tau'],
             mirostat_eta=state['mirostat_eta'],
             stream=True,
-            logits_processor=LogitsProcessorList([
-                partial(ban_eos_logits_processor, self.model.token_eos()),
-            ]) if state['ban_eos_token'] else None,
+            logits_processor=logit_processors,
         )
 
         output = ""
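Roughly, the call this hunk ends up making against llama-cpp-python looks like the sketch below; the model path and ban IDs are placeholders, and custom_token_ban_logits_processor from the previous sketch is assumed to be in scope:

from functools import partial
from llama_cpp import Llama, LogitsProcessorList

llm = Llama(model_path='/path/to/model.gguf')   # placeholder path

logit_processors = LogitsProcessorList([
    partial(custom_token_ban_logits_processor, [13, 29])   # placeholder token IDs
])

for chunk in llm.create_completion('Hello', max_tokens=16, stream=True,
                                   logits_processor=logit_processors):
    print(chunk['choices'][0]['text'], end='')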
@@ -150,6 +150,7 @@ loaders_samplers = {
         'guidance_scale',
         'negative_prompt',
         'ban_eos_token',
+        'custom_token_bans',
         'add_bos_token',
         'skip_special_tokens',
         'auto_max_new_tokens',
@@ -176,6 +177,7 @@ loaders_samplers = {
         'guidance_scale',
         'negative_prompt',
         'ban_eos_token',
+        'custom_token_bans',
         'add_bos_token',
         'skip_special_tokens',
         'auto_max_new_tokens',
@@ -191,6 +193,7 @@ loaders_samplers = {
         'guidance_scale',
         'negative_prompt',
         'ban_eos_token',
+        'custom_token_bans',
         'auto_max_new_tokens',
     },
     'ExLlamav2': {
@@ -201,6 +204,7 @@ loaders_samplers = {
         'repetition_penalty_range',
         'seed',
         'ban_eos_token',
+        'custom_token_bans',
         'auto_max_new_tokens',
     },
     'ExLlamav2_HF': {
@@ -225,6 +229,7 @@ loaders_samplers = {
         'guidance_scale',
         'negative_prompt',
         'ban_eos_token',
+        'custom_token_bans',
         'add_bos_token',
         'skip_special_tokens',
         'auto_max_new_tokens',
@@ -255,6 +260,7 @@ loaders_samplers = {
         'guidance_scale',
         'negative_prompt',
         'ban_eos_token',
+        'custom_token_bans',
         'add_bos_token',
         'skip_special_tokens',
         'auto_max_new_tokens',
@@ -285,6 +291,7 @@ loaders_samplers = {
         'guidance_scale',
         'negative_prompt',
         'ban_eos_token',
+        'custom_token_bans',
         'add_bos_token',
         'skip_special_tokens',
         'auto_max_new_tokens',
@@ -299,6 +306,7 @@ loaders_samplers = {
         'mirostat_tau',
         'mirostat_eta',
         'ban_eos_token',
+        'custom_token_bans',
     },
     'llamacpp_HF': {
         'temperature',
@@ -322,6 +330,7 @@ loaders_samplers = {
         'guidance_scale',
         'negative_prompt',
         'ban_eos_token',
+        'custom_token_bans',
         'add_bos_token',
         'skip_special_tokens',
         'auto_max_new_tokens',
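loaders_samplers maps each loader name to the set of generation parameters it understands, which is how the UI decides which controls apply to the active loader. After this commit the new flag is a member for every loader touched above; from inside the repo that can be checked directly:

from modules.loaders import loaders_samplers   # run from the text-generation-webui root

assert 'custom_token_bans' in loaders_samplers['ExLlamav2']
assert 'custom_token_bans' in loaders_samplers['llamacpp_HF']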
@@ -28,6 +28,7 @@ def default_preset():
         'num_beams': 1,
         'length_penalty': 1,
         'early_stopping': False,
+        'custom_token_bans': '',
     }
 
 
@@ -49,6 +49,7 @@ settings = {
     'auto_max_new_tokens': False,
     'max_tokens_second': 0,
     'ban_eos_token': False,
+    'custom_token_bans': '',
     'add_bos_token': True,
     'skip_special_tokens': True,
     'stream': True,
@@ -266,6 +266,14 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     if state['ban_eos_token']:
         generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
 
+    if state['custom_token_bans']:
+        to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
+        if len(to_ban) > 0:
+            if generate_params.get('suppress_tokens', None):
+                generate_params['suppress_tokens'] += to_ban
+            else:
+                generate_params['suppress_tokens'] = to_ban
+
     generate_params.update({'use_cache': not shared.args.no_cache})
     if shared.args.deepspeed:
         generate_params.update({'synced_gpus': True})
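For the Transformers path the bans ride on the stock suppress_tokens generation parameter, which sets the listed IDs to -inf at every step; custom bans are appended to an existing eos ban rather than overwriting it. A minimal standalone sketch (gpt2 is only a small stand-in model, and the extra IDs are placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

inputs = tokenizer('The capital of France is', return_tensors='pt')
banned = [tokenizer.eos_token_id] + [13, 29]   # eos ban plus custom bans, merged like generate_reply_HF does
output = model.generate(**inputs, max_new_tokens=8, suppress_tokens=banned)
print(tokenizer.decode(output[0]))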
@@ -118,6 +118,7 @@ def list_interface_input_elements():
         'guidance_scale',
         'add_bos_token',
         'ban_eos_token',
+        'custom_token_bans',
         'truncation_length',
         'custom_stopping_strings',
         'skip_special_tokens',
@@ -118,8 +118,8 @@ def create_ui(default_preset):
         with gr.Column():
             shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
             shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
+            shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Custom token bans', info='Specific token IDs to ban from generating, comma-separated. The IDs can be found in a tokenizer.json file.')
             shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
-
             shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
             shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')
 
@@ -19,6 +19,7 @@ custom_stopping_strings: ''
 auto_max_new_tokens: false
 max_tokens_second: 0
 ban_eos_token: false
+custom_token_bans: ''
 add_bos_token: true
 skip_special_tokens: true
 stream: true