mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Update truncation length based on max_seq_len/n_ctx
This commit is contained in:
parent
e6eda5c2da
commit
0c9e818bb8
@ -145,12 +145,14 @@ def create_event_handlers():
|
|||||||
apply_model_settings_to_state, gradio('model_menu', 'interface_state'), gradio('interface_state')).then(
|
apply_model_settings_to_state, gradio('model_menu', 'interface_state'), gradio('interface_state')).then(
|
||||||
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
|
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
|
||||||
update_model_parameters, gradio('interface_state'), None).then(
|
update_model_parameters, gradio('interface_state'), None).then(
|
||||||
load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False)
|
load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False).success(
|
||||||
|
update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length'))
|
||||||
|
|
||||||
shared.gradio['load_model'].click(
|
shared.gradio['load_model'].click(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
update_model_parameters, gradio('interface_state'), None).then(
|
update_model_parameters, gradio('interface_state'), None).then(
|
||||||
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False)
|
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).success(
|
||||||
|
update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length'))
|
||||||
|
|
||||||
shared.gradio['unload_model'].click(
|
shared.gradio['unload_model'].click(
|
||||||
unload_model, None, None).then(
|
unload_model, None, None).then(
|
||||||
@ -160,7 +162,8 @@ def create_event_handlers():
|
|||||||
unload_model, None, None).then(
|
unload_model, None, None).then(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
update_model_parameters, gradio('interface_state'), None).then(
|
update_model_parameters, gradio('interface_state'), None).then(
|
||||||
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False)
|
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).success(
|
||||||
|
update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length'))
|
||||||
|
|
||||||
shared.gradio['save_model_settings'].click(
|
shared.gradio['save_model_settings'].click(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
@ -235,3 +238,12 @@ def download_model_wrapper(repo_id, progress=gr.Progress()):
|
|||||||
except:
|
except:
|
||||||
progress(1.0)
|
progress(1.0)
|
||||||
yield traceback.format_exc().replace('\n', '\n\n')
|
yield traceback.format_exc().replace('\n', '\n\n')
|
||||||
|
|
||||||
|
|
||||||
|
def update_truncation_length(current_length, state):
|
||||||
|
if state['loader'] in ['ExLlama', 'ExLlama_HF']:
|
||||||
|
return state['max_seq_len']
|
||||||
|
elif state['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
|
||||||
|
return state['n_ctx']
|
||||||
|
else:
|
||||||
|
return current_length
|
||||||
|
@ -113,7 +113,7 @@ def create_ui(default_preset):
|
|||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
||||||
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
|
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
|
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
|
||||||
@ -129,3 +129,12 @@ def create_ui(default_preset):
|
|||||||
def create_event_handlers():
|
def create_event_handlers():
|
||||||
shared.gradio['filter_by_loader'].change(loaders.blacklist_samplers, gradio('filter_by_loader'), gradio(loaders.list_all_samplers()), show_progress=False)
|
shared.gradio['filter_by_loader'].change(loaders.blacklist_samplers, gradio('filter_by_loader'), gradio(loaders.list_all_samplers()), show_progress=False)
|
||||||
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()))
|
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()))
|
||||||
|
|
||||||
|
|
||||||
|
def get_truncation_length():
|
||||||
|
if shared.args.max_seq_len != shared.args_defaults.max_seq_len:
|
||||||
|
return shared.args.max_seq_len
|
||||||
|
if shared.args.n_ctx != shared.args_defaults.n_ctx:
|
||||||
|
return shared.args.n_ctx
|
||||||
|
else:
|
||||||
|
return shared.settings['truncation_length']
|
||||||
|
Loading…
Reference in New Issue
Block a user