diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py
index 14f6f9d6..493661c2 100644
--- a/api-examples/api-example-chat-stream.py
+++ b/api-examples/api-example-chat-stream.py
@@ -20,6 +20,7 @@ async def run(user_input, history):
     request = {
         'user_input': user_input,
         'max_new_tokens': 250,
+        'auto_max_new_tokens': False,
         'history': history,
         'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
         'character': 'Example',
diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py
index 0e155c63..31641815 100644
--- a/api-examples/api-example-chat.py
+++ b/api-examples/api-example-chat.py
@@ -14,6 +14,7 @@ def run(user_input, history):
     request = {
         'user_input': user_input,
         'max_new_tokens': 250,
+        'auto_max_new_tokens': False,
         'history': history,
         'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
         'character': 'Example',
diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py
index 1ae5a91c..175275f9 100644
--- a/api-examples/api-example-stream.py
+++ b/api-examples/api-example-stream.py
@@ -20,6 +20,7 @@ async def run(context):
     request = {
         'prompt': context,
         'max_new_tokens': 250,
+        'auto_max_new_tokens': False,

         # Generation params. If 'preset' is set to different than 'None', the values
         # in presets/preset-name.yaml are used instead of the individual numbers.
diff --git a/api-examples/api-example.py b/api-examples/api-example.py
index 4e45de9e..7f8bc1d2 100644
--- a/api-examples/api-example.py
+++ b/api-examples/api-example.py
@@ -12,6 +12,7 @@ def run(prompt):
     request = {
         'prompt': prompt,
         'max_new_tokens': 250,
+        'auto_max_new_tokens': False,

         # Generation params. If 'preset' is set to different than 'None', the values
         # in presets/preset-name.yaml are used instead of the individual numbers.
diff --git a/extensions/api/util.py b/extensions/api/util.py
index 2358b7d2..5cc259db 100644
--- a/extensions/api/util.py
+++ b/extensions/api/util.py
@@ -21,6 +21,7 @@ def build_parameters(body, chat=False):

     generate_params = {
         'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
+        'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
         'do_sample': bool(body.get('do_sample', True)),
         'temperature': float(body.get('temperature', 0.5)),
         'top_p': float(body.get('top_p', 1)),
diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py
index 52f0d641..cb8308e7 100644
--- a/extensions/openai/defaults.py
+++ b/extensions/openai/defaults.py
@@ -4,6 +4,7 @@ import copy
 # Data type is important, Ex. use 0.0 for a float 0
 default_req_params = {
     'max_new_tokens': 16,  # 'Inf' for chat
+    'auto_max_new_tokens': False,
     'temperature': 1.0,
     'top_p': 1.0,
     'top_k': 1,  # choose 20 for chat in absence of another default
diff --git a/modules/loaders.py b/modules/loaders.py
index 6d0291bf..838ecc86 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -116,6 +116,7 @@ loaders_samplers = {
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
+        'auto_max_new_tokens',
     },
     'ExLlama_HF': {
         'temperature',
@@ -139,6 +140,7 @@ loaders_samplers = {
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
+        'auto_max_new_tokens',
     },
     'ExLlama': {
         'temperature',
@@ -176,6 +178,7 @@ loaders_samplers = {
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
+        'auto_max_new_tokens',
     },
     'GPTQ-for-LLaMa': {
         'temperature',
@@ -203,6 +206,7 @@ loaders_samplers = {
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
+        'auto_max_new_tokens',
     },
     'llama.cpp': {
         'temperature',
@@ -237,6 +241,7 @@ loaders_samplers = {
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
+        'auto_max_new_tokens',
     },
 }

diff --git a/modules/shared.py b/modules/shared.py
index 59d49ab6..a2782e65 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -36,6 +36,7 @@ settings = {
     'max_new_tokens': 200,
     'max_new_tokens_min': 1,
     'max_new_tokens_max': 4096,
+    'auto_max_new_tokens': False,
     'seed': -1,
     'character': 'None',
     'name1': 'You',
diff --git a/modules/text_generation.py b/modules/text_generation.py
index e1be6aa3..f6f71990 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -247,6 +247,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
     output = input_ids[0]
     cuda = not any((shared.args.cpu, shared.args.deepspeed))
+    if state['auto_max_new_tokens']:
+        generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1]

     # Add the encoded tokens to generate_params
     question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
diff --git a/modules/ui.py b/modules/ui.py
index d9b3a131..fe3482d2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -79,6 +79,7 @@ def list_model_elements():
 def list_interface_input_elements():
     elements = [
         'max_new_tokens',
+        'auto_max_new_tokens',
         'seed',
         'temperature',
         'top_p',
diff --git a/server.py b/server.py
index 0f1b9332..d622cdbe 100644
--- a/server.py
+++ b/server.py
@@ -425,6 +425,7 @@ def create_settings_menus(default_preset):
                         shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
                         shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
                     with gr.Column():
+                        shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
                         shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
                         shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
diff --git a/settings-template.yaml b/settings-template.yaml
index 3d6585d3..62e86371 100644
--- a/settings-template.yaml
+++ b/settings-template.yaml
@@ -3,6 +3,7 @@ autoload_model: false
 max_new_tokens: 200
 max_new_tokens_min: 1
 max_new_tokens_max: 4096
+auto_max_new_tokens: false
 seed: -1
 character: None
 name1: You
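
Usage note, not part of the patch: a minimal sketch of how a client could enable the new parameter over the blocking API, following the style of the api-example.py script above. The localhost:5000 host, the /api/v1/generate route, the prompt text, and the results/text response shape are assumptions rather than anything this diff establishes; server-side, generate_reply_HF then replaces max_new_tokens with truncation_length minus the prompt length for HF-based loaders.

# Sketch only; endpoint and response shape assumed from the stock example scripts.
import requests

URI = 'http://localhost:5000/api/v1/generate'  # assumed default blocking-API endpoint

request = {
    'prompt': 'Once upon a time, ',
    'max_new_tokens': 250,         # overridden server-side when auto_max_new_tokens is True (HF loaders)
    'auto_max_new_tokens': True,   # expand max_new_tokens to truncation_length - prompt length
    'do_sample': True,
    'temperature': 0.7,
}

response = requests.post(URI, json=request)
if response.status_code == 200:
    # Response format assumed: {'results': [{'text': ...}]}
    print(response.json()['results'][0]['text'])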