2023-06-14 01:34:35 +02:00
|
|
|
import functools
|
2024-01-07 22:07:32 +01:00
|
|
|
import pprint
|
2023-11-18 22:31:41 +01:00
|
|
|
import random
|
2023-06-14 01:34:35 +02:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
|
2023-11-18 22:31:41 +01:00
|
|
|
from modules import shared
|
|
|
|
from modules.loaders import loaders_samplers
|
2024-01-07 22:07:32 +01:00
|
|
|
from modules.logging_colors import logger
|
2023-11-18 22:31:41 +01:00
|
|
|
|
2023-06-14 01:34:35 +02:00
|
|
|
|
2023-08-01 04:13:29 +02:00
|
|
|
def default_preset():
    """Return a fresh dict holding the default value of every preset parameter.

    A new dict is built on every call, so callers may mutate the result
    without affecting later calls. The key order of this dict is the
    canonical parameter order used elsewhere in this module
    (``presets_params()`` iterates it).
    """
    return dict(
        temperature=1,
        temperature_last=False,
        dynamic_temperature=False,
        dynamic_temperature_low=0.1,
        top_p=1,
        min_p=0,
        top_k=0,
        repetition_penalty=1,
        presence_penalty=0,
        frequency_penalty=0,
        repetition_penalty_range=1024,
        typical_p=1,
        tfs=1,
        top_a=0,
        epsilon_cutoff=0,
        eta_cutoff=0,
        guidance_scale=1,
        penalty_alpha=0,
        mirostat_mode=0,
        mirostat_tau=5,
        mirostat_eta=0.1,
        do_sample=True,
        encoder_repetition_penalty=1,
        no_repeat_ngram_size=0,
        min_length=0,
        num_beams=1,
        length_penalty=1,
        early_stopping=False,
    )
|
|
|
|
|
2023-08-01 04:13:29 +02:00
|
|
|
|
2023-08-06 22:22:48 +02:00
|
|
|
def presets_params():
    """Return the list of preset parameter names, in ``default_preset()`` order."""
    return list(default_preset())
|
|
|
|
|
|
|
|
|
2023-08-01 04:13:29 +02:00
|
|
|
def load_preset(name):
    """Return generation parameters for the preset called *name*.

    Starts from ``default_preset()`` and overlays every key found in the
    YAML file ``presets/<name>.yaml``. The values ``'None'``, ``None`` and
    ``''`` mean "no preset" and yield the defaults unchanged. An empty
    preset file (which ``yaml.safe_load`` parses to ``None``) is treated as
    having no overrides.

    The path is relative, so it is resolved against the current working
    directory (NOTE(review): presumably the application root — confirm).

    Raises:
        FileNotFoundError: if the preset file does not exist (from ``open``).
    """
    generate_params = default_preset()
    if name not in ['None', None, '']:
        # Explicit UTF-8 so preset files read the same on every platform.
        with open(Path(f'presets/{name}.yaml'), 'r', encoding='utf-8') as infile:
            preset = yaml.safe_load(infile)

        # An empty YAML document loads as None; iterating it would raise
        # TypeError, so only overlay when there is something to overlay.
        if preset:
            generate_params.update(preset)

    return generate_params
|
|
|
|
|
|
|
|
|
|
|
|
@functools.cache
|
|
|
|
def load_preset_memoized(name):
|
|
|
|
return load_preset(name)
|
|
|
|
|
|
|
|
|
|
|
|
def load_preset_for_ui(name, state):
    """Merge preset *name* into *state* and return the per-parameter values.

    *state* is updated in place with every parameter of the preset. The
    return value is a tuple: *state* first, then the preset's value for each
    parameter in ``presets_params()`` order.
    """
    preset_values = load_preset(name)
    state.update(preset_values)
    ordered_values = [preset_values[param] for param in presets_params()]
    return (state, *ordered_values)
|
2023-06-14 01:34:35 +02:00
|
|
|
|
|
|
|
|
2023-11-18 22:31:41 +01:00
|
|
|
def random_preset(state):
    """Randomise sampler parameters, write them into *state*, and return them.

    Steps:
      1. Start from ``default_preset()``.
      2. For each category in ``params_and_values``, keep only the
         parameters the active loader supports, according to
         ``loaders_samplers[shared.args.loader]``. The filter applies only
         when a loader is set.
      3. Assign random values:
         - For ``'other'``, a random non-empty subset of parameters is given
           random values.
         - For every other category, exactly one parameter is chosen and
           given a random value.
      4. If ``dynamic_temperature`` ended up enabled,
         ``dynamic_temperature_low`` and ``temperature`` are re-drawn
         together until low < high. Otherwise ``dynamic_temperature_low`` is
         reset to its default.

    *state* is updated in place. The parameters that differ from the
    defaults are logged. Returns a tuple of *state* followed by every
    preset value in ``presets_params()`` order.
    """
    params_and_values = {
        'remove_tail_tokens': {
            'top_p': [0.5, 0.8, 0.9, 0.95, 0.99],
            'min_p': [0.5, 0.2, 0.1, 0.05, 0.01],
            'top_k': [3, 5, 10, 20, 30, 40],
            'typical_p': [0.2, 0.575, 0.95],
            'tfs': [0.5, 0.8, 0.9, 0.95, 0.99],
            'top_a': [0.5, 0.2, 0.1, 0.05, 0.01],
            'epsilon_cutoff': [1, 3, 5, 7, 9],
            'eta_cutoff': [3, 6, 9, 12, 15, 18],
        },
        'flatten_distribution': {
            'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 3.0, 5.0],
            'dynamic_temperature_low': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 3.0],
        },
        'repetition': {
            'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
            'presence_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
            'frequency_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
        },
        'other': {
            'temperature_last': [True, False],
            'dynamic_temperature': [True, False],
        }
    }

    generate_params = default_preset()
    defaults = default_preset()

    for cat, candidates in params_and_values.items():
        choices = list(candidates)
        # Only consider samplers the active loader supports.
        if shared.args.loader is not None:
            choices = [x for x in choices if x in loaders_samplers[shared.args.loader]]

        if not choices:
            continue

        if cat == 'other':
            # Randomise a random, non-empty subset of these parameters.
            # The randint call happens before sample(), as in the original flow.
            for choice in random.sample(choices, random.randint(1, len(choices))):
                generate_params[choice] = random.choice(candidates[choice])
        else:
            # Exactly one parameter per category.
            choice = random.choice(choices)
            generate_params[choice] = random.choice(candidates[choice])

    if generate_params['dynamic_temperature']:
        # Dynamic temperature uses a low/high pair.
        # Draw both together and retry until low < high. Valid pairs exist
        # in the candidate lists (e.g. 0.5 < 0.7), so the loop terminates
        # with probability 1.
        temps = params_and_values['flatten_distribution']
        while True:
            generate_params['dynamic_temperature_low'] = random.choice(temps['dynamic_temperature_low'])
            generate_params['temperature'] = random.choice(temps['temperature'])
            if generate_params['dynamic_temperature_low'] < generate_params['temperature']:
                break
    else:
        # The category loop may have randomised the low bound even though
        # dynamic temperature is off; restore its default.
        generate_params['dynamic_temperature_low'] = defaults['dynamic_temperature_low']

    state.update(generate_params)

    # Keep only the entries that differ from the defaults, for logging.
    diff = {k: v for k, v in generate_params.items() if v != defaults[k]}

    logger.info("GENERATED_PRESET=")
    pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(diff)
    print()

    return state, *[generate_params[k] for k in presets_params()]
|
|
|
|
|
|
|
|
|
2023-06-14 01:34:35 +02:00
|
|
|
def generate_preset_yaml(state):
    """Serialise the preset parameters in *state* to a YAML string.

    Only parameters whose value in *state* differs from ``default_preset()``
    are emitted. Keys appear in ``presets_params()`` order because
    ``sort_keys=False`` is passed to ``yaml.dump``. Every preset parameter
    must be present in *state*; a missing one raises ``KeyError``.
    """
    defaults = default_preset()
    current = {param: state[param] for param in presets_params()}
    non_default = {param: value for param, value in current.items() if value != defaults[param]}
    return yaml.dump(non_default, sort_keys=False)
|