mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add auto_max_new_tokens parameter (#3419)
This commit is contained in:
parent
0d9932815c
commit
e931844fe2
@ -20,6 +20,7 @@ async def run(user_input, history):
|
||||
request = {
|
||||
'user_input': user_input,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'history': history,
|
||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||
'character': 'Example',
|
||||
|
@ -14,6 +14,7 @@ def run(user_input, history):
|
||||
request = {
|
||||
'user_input': user_input,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'history': history,
|
||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||
'character': 'Example',
|
||||
|
@ -20,6 +20,7 @@ async def run(context):
|
||||
request = {
|
||||
'prompt': context,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
|
||||
# Generation params. If 'preset' is set to anything other than 'None', the values
|
||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
||||
|
@ -12,6 +12,7 @@ def run(prompt):
|
||||
request = {
|
||||
'prompt': prompt,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
|
||||
# Generation params. If 'preset' is set to anything other than 'None', the values
|
||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
||||
|
@ -21,6 +21,7 @@ def build_parameters(body, chat=False):
|
||||
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
|
||||
'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
|
||||
'do_sample': bool(body.get('do_sample', True)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
|
@ -4,6 +4,7 @@ import copy
|
||||
# Data type is important, e.g. use 0.0 for a float zero
|
||||
default_req_params = {
|
||||
'max_new_tokens': 16, # 'Inf' for chat
|
||||
'auto_max_new_tokens': False,
|
||||
'temperature': 1.0,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1, # choose 20 for chat in absence of another default
|
||||
|
@ -116,6 +116,7 @@ loaders_samplers = {
|
||||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'ExLlama_HF': {
|
||||
'temperature',
|
||||
@ -139,6 +140,7 @@ loaders_samplers = {
|
||||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'ExLlama': {
|
||||
'temperature',
|
||||
@ -176,6 +178,7 @@ loaders_samplers = {
|
||||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'GPTQ-for-LLaMa': {
|
||||
'temperature',
|
||||
@ -203,6 +206,7 @@ loaders_samplers = {
|
||||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'llama.cpp': {
|
||||
'temperature',
|
||||
@ -237,6 +241,7 @@ loaders_samplers = {
|
||||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -36,6 +36,7 @@ settings = {
|
||||
'max_new_tokens': 200,
|
||||
'max_new_tokens_min': 1,
|
||||
'max_new_tokens_max': 4096,
|
||||
'auto_max_new_tokens': False,
|
||||
'seed': -1,
|
||||
'character': 'None',
|
||||
'name1': 'You',
|
||||
|
@ -247,6 +247,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
output = input_ids[0]
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed))
|
||||
if state['auto_max_new_tokens']:
|
||||
generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1]
|
||||
|
||||
# Add the encoded tokens to generate_params
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
|
@ -79,6 +79,7 @@ def list_model_elements():
|
||||
def list_interface_input_elements():
|
||||
elements = [
|
||||
'max_new_tokens',
|
||||
'auto_max_new_tokens',
|
||||
'seed',
|
||||
'temperature',
|
||||
'top_p',
|
||||
|
@ -425,6 +425,7 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
||||
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
|
||||
with gr.Column():
|
||||
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
|
||||
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
|
||||
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||
|
||||
|
@ -3,6 +3,7 @@ autoload_model: false
|
||||
max_new_tokens: 200
|
||||
max_new_tokens_min: 1
|
||||
max_new_tokens_max: 4096
|
||||
auto_max_new_tokens: false
|
||||
seed: -1
|
||||
character: None
|
||||
name1: You
|
||||
|
Loading…
Reference in New Issue
Block a user