mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Make the bos token optional
This commit is contained in:
parent
4961f43702
commit
bd04ff27ad
@ -35,6 +35,7 @@ settings = {
|
||||
'greeting': 'Hello there!',
|
||||
'end_of_turn': '',
|
||||
'stop_at_newline': False,
|
||||
'add_bos_token': True,
|
||||
'chat_prompt_size': 2048,
|
||||
'chat_prompt_size_min': 0,
|
||||
'chat_prompt_size_max': 2048,
|
||||
|
@ -22,7 +22,7 @@ def get_max_prompt_length(tokens):
|
||||
return max_length
|
||||
|
||||
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True, add_bos_token=True):
|
||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||
input_ids = shared.tokenizer.encode(str(prompt))
|
||||
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
||||
@ -30,6 +30,12 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
else:
|
||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
||||
|
||||
# This is a hack for making replies more creative.
|
||||
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
# Llama adds this extra token when the first character is '\n', and this
|
||||
# compromises the stopping criteria, so we just remove it
|
||||
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
@ -158,7 +164,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
input_ids = encode(question, generate_state['max_new_tokens'])
|
||||
input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token'])
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
|
||||
|
@ -233,7 +233,7 @@ def create_model_menus():
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
||||
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
|
||||
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts', 'add_bos_token']:
|
||||
generate_params[k] = shared.settings[k]
|
||||
shared.gradio['generate_state'] = gr.State(generate_params)
|
||||
|
||||
@ -273,6 +273,7 @@ def create_settings_menus(default_preset):
|
||||
with gr.Column():
|
||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||
|
||||
with gr.Accordion('Soft prompt', open=False):
|
||||
with gr.Row():
|
||||
@ -610,7 +611,7 @@ def create_interface():
|
||||
d[key] = value
|
||||
return d
|
||||
|
||||
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
|
||||
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'add_bos_token', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
|
||||
if k not in shared.gradio:
|
||||
continue
|
||||
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
|
||||
|
Loading…
Reference in New Issue
Block a user