mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
New feature: "random preset" button (#4647)
This commit is contained in:
parent
d1a58da52f
commit
83b64e7fc1
@ -11,9 +11,13 @@ LLMs work by generating one token at a time. Given your prompt, the model calcul
|
|||||||
|
|
||||||
### Preset menu
|
### Preset menu
|
||||||
|
|
||||||
Can be used to save combinations of parameters for reuse.
|
Can be used to save and load combinations of parameters for reuse.
|
||||||
|
|
||||||
The built-in presets were not manually chosen. They were obtained after a blind contest called "Preset Arena" where hundreds of people voted. The full results can be found [here](https://github.com/oobabooga/oobabooga.github.io/blob/main/arena/results.md).
|
* **🎲 button**: creates a random yet interpretable preset. Only 1 parameter of each category is included for the categories: removing tail tokens, avoiding repetition, and flattening the distribution. That is, top_p and top_k are not mixed, and neither are repetition_penalty and frequency_penalty. You can use this button to break out of a loop of bad generations after multiple "Regenerate" attempts.
|
||||||
|
|
||||||
|
#### Built-in presets
|
||||||
|
|
||||||
|
These were obtained after a blind contest called "Preset Arena" where hundreds of people voted. The full results can be found [here](https://github.com/oobabooga/oobabooga.github.io/blob/main/arena/results.md).
|
||||||
|
|
||||||
A key takeaway is that the best presets are:
|
A key takeaway is that the best presets are:
|
||||||
|
|
||||||
|
@ -1,8 +1,12 @@
|
|||||||
import functools
|
import functools
|
||||||
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.loaders import loaders_samplers
|
||||||
|
|
||||||
|
|
||||||
def default_preset():
|
def default_preset():
|
||||||
return {
|
return {
|
||||||
@ -63,6 +67,45 @@ def load_preset_for_ui(name, state):
|
|||||||
return state, *[generate_params[k] for k in presets_params()]
|
return state, *[generate_params[k] for k in presets_params()]
|
||||||
|
|
||||||
|
|
||||||
|
def random_preset(state):
|
||||||
|
params_and_values = {
|
||||||
|
'remove_tail_tokens': {
|
||||||
|
'top_p': [0.5, 0.8, 0.9, 0.95, 0.99],
|
||||||
|
'min_p': [0.5, 0.2, 0.1, 0.05, 0.01],
|
||||||
|
'top_k': [3, 5, 10, 20, 30, 40],
|
||||||
|
'typical_p': [0.2, 0.575, 0.95],
|
||||||
|
'tfs': [0.5, 0.8, 0.9, 0.95, 0.99],
|
||||||
|
'top_a': [0.5, 0.2, 0.1, 0.05, 0.01],
|
||||||
|
'epsilon_cutoff': [1, 3, 5, 7, 9],
|
||||||
|
'eta_cutoff': [3, 6, 9, 12, 15, 18],
|
||||||
|
},
|
||||||
|
'flatten_distribution': {
|
||||||
|
'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0],
|
||||||
|
},
|
||||||
|
'repetition': {
|
||||||
|
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
|
||||||
|
'presence_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
|
||||||
|
'frequency_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
|
||||||
|
},
|
||||||
|
'other': {
|
||||||
|
'temperature_last': [True, False],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
generate_params = default_preset()
|
||||||
|
for cat in params_and_values:
|
||||||
|
choices = list(params_and_values[cat].keys())
|
||||||
|
if shared.args.loader is not None:
|
||||||
|
choices = [x for x in choices if x in loaders_samplers[shared.args.loader]]
|
||||||
|
|
||||||
|
if len(choices) > 0:
|
||||||
|
choice = random.choice(choices)
|
||||||
|
generate_params[choice] = random.choice(params_and_values[cat][choice])
|
||||||
|
|
||||||
|
state.update(generate_params)
|
||||||
|
return state, *[generate_params[k] for k in presets_params()]
|
||||||
|
|
||||||
|
|
||||||
def generate_preset_yaml(state):
|
def generate_preset_yaml(state):
|
||||||
defaults = default_preset()
|
defaults = default_preset()
|
||||||
data = {k: state[k] for k in presets_params()}
|
data = {k: state[k] for k in presets_params()}
|
||||||
|
@ -18,6 +18,7 @@ def create_ui(default_preset):
|
|||||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': utils.get_available_presets()}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': utils.get_available_presets()}, 'refresh-button', interactive=not mu)
|
||||||
shared.gradio['save_preset'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu)
|
shared.gradio['save_preset'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu)
|
||||||
shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button', interactive=not mu)
|
shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button', interactive=not mu)
|
||||||
|
shared.gradio['random_preset'] = gr.Button('🎲', elem_classes='refresh-button')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['filter_by_loader'] = gr.Dropdown(label="Filter by loader", choices=["All"] + list(loaders.loaders_and_params.keys()), value="All", elem_classes='slim-dropdown')
|
shared.gradio['filter_by_loader'] = gr.Dropdown(label="Filter by loader", choices=["All"] + list(loaders.loaders_and_params.keys()), value="All", elem_classes='slim-dropdown')
|
||||||
@ -90,6 +91,7 @@ def create_ui(default_preset):
|
|||||||
def create_event_handlers():
|
def create_event_handlers():
|
||||||
shared.gradio['filter_by_loader'].change(loaders.blacklist_samplers, gradio('filter_by_loader'), gradio(loaders.list_all_samplers()), show_progress=False)
|
shared.gradio['filter_by_loader'].change(loaders.blacklist_samplers, gradio('filter_by_loader'), gradio(loaders.list_all_samplers()), show_progress=False)
|
||||||
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()))
|
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()))
|
||||||
|
shared.gradio['random_preset'].click(presets.random_preset, gradio('interface_state'), gradio('interface_state') + gradio(presets.presets_params()))
|
||||||
shared.gradio['grammar_file'].change(load_grammar, gradio('grammar_file'), gradio('grammar_string'))
|
shared.gradio['grammar_file'].change(load_grammar, gradio('grammar_file'), gradio('grammar_string'))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user