mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 01:09:22 +01:00
Add max_tokens_second param (#3533)
This commit is contained in:
parent
fe1f7c6513
commit
cec8db52e5
@ -22,6 +22,7 @@ async def run(user_input, history):
|
||||
'user_input': user_input,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
'history': history,
|
||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||
'character': 'Example',
|
||||
|
@ -16,6 +16,7 @@ def run(user_input, history):
|
||||
'user_input': user_input,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
'history': history,
|
||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||
'character': 'Example',
|
||||
|
@ -21,6 +21,7 @@ async def run(context):
|
||||
'prompt': context,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
|
||||
# Generation params. If 'preset' is set to different than 'None', the values
|
||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
||||
|
@ -13,6 +13,7 @@ def run(prompt):
|
||||
'prompt': prompt,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
|
||||
# Generation params. If 'preset' is set to different than 'None', the values
|
||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
||||
|
@ -22,6 +22,7 @@ def build_parameters(body, chat=False):
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
|
||||
'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
|
||||
'max_tokens_second': int(body.get('max_tokens_second', 0)),
|
||||
'do_sample': bool(body.get('do_sample', True)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
|
@ -5,6 +5,7 @@ import copy
|
||||
default_req_params = {
|
||||
'max_new_tokens': 16, # 'Inf' for chat
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
'temperature': 1.0,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1, # choose 20 for chat in absence of another default
|
||||
|
@ -47,6 +47,7 @@ settings = {
|
||||
'truncation_length_max': 16384,
|
||||
'custom_stopping_strings': '',
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
'ban_eos_token': False,
|
||||
'add_bos_token': True,
|
||||
'skip_special_tokens': True,
|
||||
|
@ -80,10 +80,22 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
|
||||
if is_stream:
|
||||
cur_time = time.time()
|
||||
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
|
||||
last_update = cur_time
|
||||
|
||||
# Maximum number of tokens/second
|
||||
if state['max_tokens_second'] > 0:
|
||||
diff = 1 / state['max_tokens_second'] - (cur_time - last_update)
|
||||
if diff > 0:
|
||||
time.sleep(diff)
|
||||
|
||||
last_update = time.time()
|
||||
yield reply
|
||||
|
||||
# Limit updates to 24 per second to not stress low latency networks
|
||||
else:
|
||||
if cur_time - last_update > 0.041666666666666664:
|
||||
last_update = cur_time
|
||||
yield reply
|
||||
|
||||
if stop_found:
|
||||
break
|
||||
|
||||
|
@ -93,6 +93,7 @@ def list_interface_input_elements():
|
||||
elements = [
|
||||
'max_new_tokens',
|
||||
'auto_max_new_tokens',
|
||||
'max_tokens_second',
|
||||
'seed',
|
||||
'temperature',
|
||||
'top_p',
|
||||
|
@ -105,7 +105,6 @@ def create_ui(default_preset):
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.')
|
||||
|
||||
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.')
|
||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
@ -114,6 +113,7 @@ def create_ui(default_preset):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
||||
shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum number of tokens/second', info='To make text readable in real time.')
|
||||
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
|
||||
with gr.Column():
|
||||
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
|
||||
|
@ -17,6 +17,7 @@ truncation_length_min: 0
|
||||
truncation_length_max: 16384
|
||||
custom_stopping_strings: ''
|
||||
auto_max_new_tokens: false
|
||||
max_tokens_second: 0
|
||||
ban_eos_token: false
|
||||
add_bos_token: true
|
||||
skip_special_tokens: true
|
||||
|
Loading…
Reference in New Issue
Block a user