From 8ac36369663107a005432afb6cbf694e4789cd5c Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 21 May 2023 15:11:57 -0300
Subject: [PATCH] Add epsilon_cutoff/eta_cutoff parameters (#2258)

---
 api-example-chat-stream.py  |  2 ++
 api-example-chat.py         |  2 ++
 api-example-stream.py       |  2 ++
 api-example.py              |  2 ++
 extensions/api/util.py      |  2 ++
 extensions/openai/script.py |  4 ++++
 modules/text_generation.py  |  4 ++++
 modules/ui.py               |  2 +-
 server.py                   | 25 +++++++++++++++++--------
 9 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/api-example-chat-stream.py b/api-example-chat-stream.py
index fb048c60..55cd706d 100644
--- a/api-example-chat-stream.py
+++ b/api-example-chat-stream.py
@@ -36,6 +36,8 @@ async def run(user_input, history):
         'temperature': 0.7,
         'top_p': 0.1,
         'typical_p': 1,
+        'epsilon_cutoff': 0,  # In units of 1e-4
+        'eta_cutoff': 0,  # In units of 1e-4
         'repetition_penalty': 1.18,
         'top_k': 40,
         'min_length': 0,
diff --git a/api-example-chat.py b/api-example-chat.py
index 7da92e6c..3a98008e 100644
--- a/api-example-chat.py
+++ b/api-example-chat.py
@@ -30,6 +30,8 @@ def run(user_input, history):
         'temperature': 0.7,
         'top_p': 0.1,
         'typical_p': 1,
+        'epsilon_cutoff': 0,  # In units of 1e-4
+        'eta_cutoff': 0,  # In units of 1e-4
         'repetition_penalty': 1.18,
         'top_k': 40,
         'min_length': 0,
diff --git a/api-example-stream.py b/api-example-stream.py
index ad8f7bf8..326bac45 100644
--- a/api-example-stream.py
+++ b/api-example-stream.py
@@ -24,6 +24,8 @@ async def run(context):
         'temperature': 1.3,
         'top_p': 0.1,
         'typical_p': 1,
+        'epsilon_cutoff': 0,  # In units of 1e-4
+        'eta_cutoff': 0,  # In units of 1e-4
         'repetition_penalty': 1.18,
         'top_k': 40,
         'min_length': 0,
diff --git a/api-example.py b/api-example.py
index f35ea1db..45c2864d 100644
--- a/api-example.py
+++ b/api-example.py
@@ -16,6 +16,8 @@ def run(prompt):
         'temperature': 1.3,
         'top_p': 0.1,
         'typical_p': 1,
+        'epsilon_cutoff': 0,  # In units of 1e-4
+        'eta_cutoff': 0,  # In units of 1e-4
         'repetition_penalty': 1.18,
         'top_k': 40,
         'min_length': 0,
diff --git a/extensions/api/util.py b/extensions/api/util.py
index 369381e3..27c6d7e9 100644
--- a/extensions/api/util.py
+++ b/extensions/api/util.py
@@ -15,6 +15,8 @@ def build_parameters(body, chat=False):
         'temperature': float(body.get('temperature', 0.5)),
         'top_p': float(body.get('top_p', 1)),
         'typical_p': float(body.get('typical_p', body.get('typical', 1))),
+        'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)),
+        'eta_cutoff': float(body.get('eta_cutoff', 0)),
         'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
         'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
         'top_k': int(body.get('top_k', 0)),
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index bdababb9..9e560d94 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -208,6 +208,8 @@ class Handler(BaseHTTPRequestHandler):
             'add_bos_token': shared.settings.get('add_bos_token', True),
             'do_sample': True,
             'typical_p': 1.0,
+            'epsilon_cutoff': 0,  # In units of 1e-4
+            'eta_cutoff': 0,  # In units of 1e-4
             'min_length': 0,
             'no_repeat_ngram_size': 0,
             'num_beams': 1,
@@ -516,6 +518,8 @@ class Handler(BaseHTTPRequestHandler):
             'add_bos_token': shared.settings.get('add_bos_token', True),
             'do_sample': True,
             'typical_p': 1.0,
+            'epsilon_cutoff': 0,  # In units of 1e-4
+            'eta_cutoff': 0,  # In units of 1e-4
             'min_length': 0,
             'no_repeat_ngram_size': 0,
             'num_beams': 1,
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 2aa4f71b..253bd302 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -190,6 +190,10 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
     for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
         generate_params[k] = state[k]
 
+    for k in ['epsilon_cutoff', 'eta_cutoff']:
+        if state[k] > 0:
+            generate_params[k] = state[k] * 1e-4
+
     if state['ban_eos_token']:
         generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
 
diff --git a/modules/ui.py b/modules/ui.py
index a3e9c2ca..702ec99c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -37,7 +37,7 @@ def list_model_elements():
 
 
 def list_interface_input_elements(chat=False):
-    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream']
+    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream']
     if chat:
         elements += ['name1', 'name2', 'greeting', 'context', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command']
 
diff --git a/server.py b/server.py
index 669b2d62..80b0334a 100644
--- a/server.py
+++ b/server.py
@@ -84,6 +84,8 @@ def load_preset_values(preset_menu, state, return_dict=False):
         'temperature': 1,
         'top_p': 1,
         'typical_p': 1,
+        'epsilon_cutoff': 0,
+        'eta_cutoff': 0,
         'repetition_penalty': 1,
         'encoder_repetition_penalty': 1,
         'top_k': 50,
@@ -100,13 +102,13 @@ def load_preset_values(preset_menu, state, return_dict=False):
         i = i.rstrip(',').strip().split('=')
         if len(i) == 2 and i[0].strip() != 'tokens':
             generate_params[i[0].strip()] = eval(i[1].strip())
-        generate_params['temperature'] = min(1.99, generate_params['temperature'])
 
+    generate_params['temperature'] = min(1.99, generate_params['temperature'])
     if return_dict:
         return generate_params
     else:
         state.update(generate_params)
-        return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+        return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
 
 
 def upload_soft_prompt(file):
@@ -453,17 +455,24 @@ def create_settings_menus(default_preset):
                         shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
 
         with gr.Column():
             with gr.Box():
-                gr.Markdown('Contrastive search')
-                shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
-
-                gr.Markdown('Beam search (uses a lot of VRAM)')
                 with gr.Row():
                     with gr.Column():
+                        gr.Markdown('Contrastive search')
+                        shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
+
+                        gr.Markdown('Beam search (uses a lot of VRAM)')
                         shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
-                    with gr.Column():
                         shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
+                    with gr.Column():
+                        gr.Markdown('Other')
+                        with gr.Row():
+                            with gr.Column():
+                                shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff', info='In units of 1e-4')
+                            with gr.Column():
+                                shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff', info='In units of 1e-4')
+
     with gr.Box():
         with gr.Row():
             with gr.Column():
@@ -485,7 +494,7 @@ def create_settings_menus(default_preset):
         with gr.Row():
             shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
 
-    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
     shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
 
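
Background on the new parameters: epsilon_cutoff and eta_cutoff expose epsilon and eta sampling from Hewitt et al., "Truncation Sampling as Language Model Desmoothing" (2022), which transformers implements as EpsilonLogitsWarper and EtaLogitsWarper. Epsilon sampling drops candidate tokens whose probability falls below a fixed floor; eta sampling makes that floor entropy-dependent, so flat (high-entropy) distributions are truncated less aggressively. A minimal NumPy sketch of the truncation rule, assuming the formulas used by the transformers warpers; the function name and the 1e-12 log guard are illustrative, not part of this patch:

    import numpy as np

    def truncate_probs(probs, epsilon_cutoff=0.0, eta_cutoff=0.0):
        # probs: 1-D array of next-token probabilities summing to 1.
        # Arguments here are absolute probabilities; the UI/API values are
        # in units of 1e-4, so a slider value of 3 corresponds to 3e-4.
        keep = np.ones_like(probs, dtype=bool)
        if epsilon_cutoff > 0:
            # Epsilon sampling: fixed probability floor.
            keep &= probs >= epsilon_cutoff
        if eta_cutoff > 0:
            # Eta sampling: the floor shrinks as entropy grows,
            # eta = min(eta_cutoff, sqrt(eta_cutoff) * exp(-H(p))).
            entropy = -np.sum(probs * np.log(probs + 1e-12))
            eta = min(eta_cutoff, np.sqrt(eta_cutoff) * np.exp(-entropy))
            keep &= probs >= eta
        keep[np.argmax(probs)] = True  # never drop every token
        out = np.where(keep, probs, 0.0)
        return out / out.sum()

With epsilon_cutoff the cut is constant; with eta_cutoff a near-uniform distribution gets a much lower effective floor, which is the point of the desmoothing argument.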
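Client-side, the new fields follow the convention visible in the api-example scripts above: 0 disables the filter, and positive values are interpreted in units of 1e-4, which modules/text_generation.py rescales with state[k] * 1e-4 before the model's generate call. A sketch of a blocking API request using the new fields, modeled on api-example.py; the endpoint path, response shape, prompt, and chosen cutoff values are assumptions for illustration:

    import requests

    HOST = 'localhost:5000'  # assumed default host:port of the example scripts
    URI = f'http://{HOST}/api/v1/generate'

    request = {
        'prompt': 'Once upon a time',
        'max_new_tokens': 200,
        'do_sample': True,
        'epsilon_cutoff': 3,  # in units of 1e-4 -> 0.0003 probability floor
        'eta_cutoff': 3,      # in units of 1e-4 -> entropy-adaptive floor
    }

    response = requests.post(URI, json=request)
    if response.status_code == 200:
        print(response.json()['results'][0]['text'])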