diff --git a/modules/shared.py b/modules/shared.py index e26489ee..62cd20d3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -35,6 +35,7 @@ settings = { 'greeting': 'Hello there!', 'end_of_turn': '', 'stop_at_newline': False, + 'add_bos_token': True, 'chat_prompt_size': 2048, 'chat_prompt_size_min': 0, 'chat_prompt_size_max': 2048, diff --git a/modules/text_generation.py b/modules/text_generation.py index 8846eaff..40bdf838 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -22,7 +22,7 @@ def get_max_prompt_length(tokens): return max_length -def encode(prompt, tokens_to_generate=0, add_special_tokens=True): +def encode(prompt, tokens_to_generate=0, add_special_tokens=True, add_bos_token=True): if any((shared.is_RWKV, shared.is_llamacpp)): input_ids = shared.tokenizer.encode(str(prompt)) input_ids = np.array(input_ids).reshape(1, len(input_ids)) @@ -30,6 +30,12 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True): else: input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) + # This is a hack for making replies more creative. + if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id: + input_ids = input_ids[:, 1:] + + # Llama adds this extra token when the first character is '\n', and this + # compromises the stopping criteria, so we just remove it if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871: input_ids = input_ids[:, 1:] @@ -158,7 +164,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[] print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') return - input_ids = encode(question, generate_state['max_new_tokens']) + input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token']) original_input_ids = input_ids output = input_ids[0] diff --git a/server.py b/server.py index 72c58d29..114584be 100644 --- a/server.py +++ b/server.py @@ -233,7 +233,7 @@ def create_model_menus(): def create_settings_menus(default_preset): generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) - for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']: + for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts', 'add_bos_token']: generate_params[k] = shared.settings[k] shared.gradio['generate_state'] = gr.State(generate_params) @@ -273,6 +273,7 @@ def create_settings_menus(default_preset): with gr.Column(): shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') + shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.') with gr.Accordion('Soft prompt', open=False): with gr.Row(): @@ -610,7 +611,7 @@ def create_interface(): d[key] = value return d - for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']: + for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'add_bos_token', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']: if k not in shared.gradio: continue if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]: