From c55b8ce932d90e4211eedf52ddd50abed3ac0876 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 6 Feb 2024 08:51:34 -0800 Subject: [PATCH] Improved random preset generation --- modules/presets.py | 53 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/modules/presets.py b/modules/presets.py index 2686355d..fe25ac5c 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -1,4 +1,5 @@ import functools +import pprint import random from pathlib import Path @@ -90,7 +91,25 @@ def random_preset(state): 'eta_cutoff': [3, 6, 9, 12, 15, 18], }, 'flatten_distribution': { - 'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0], + 'temperature': [0.1, 0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 5.0], + 'dynamic_temperature': [ + [0.1, 1], + [0.1, 1.5], + [0.1, 2], + [0.1, 5], + [0.5, 1], + [0.5, 1.5], + [0.5, 2], + [0.5, 5], + [0.8, 1], + [0.8, 1.5], + [0.8, 2], + [0.8, 5], + [1, 1.5], + [1, 2], + [1, 5] + ], + 'smoothing_factor': [0.2, 0.3, 0.6, 1.2] }, 'repetition': { 'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25], @@ -106,26 +125,42 @@ def random_preset(state): for cat in params_and_values: choices = list(params_and_values[cat].keys()) if shared.args.loader is not None: - choices = [x for x in choices if x in loaders_samplers[shared.args.loader]] + choices = [x for x in choices if loader_contains(sampler)] if len(choices) > 0: choice = random.choice(choices) - generate_params[choice] = random.choice(params_and_values[cat][choice]) + value = random.choice(params_and_values[cat][choice]) + if choice == 'dynamic_temperature': + generate_params['dynamic_temperature'] = True + generate_params['dynatemp_low'] = value[0] + generate_params['dynatemp_high'] = value[1] + else: + generate_params[choice] = value state.update(generate_params) + logger.info("GENERATED_PRESET=") + pprint.PrettyPrinter(indent=4, width=1, sort_dicts=False).pprint(remove_defaults(state)) return state, *[generate_params[k] for k in presets_params()] -def generate_preset_yaml(state): +def loader_contains(sampler): + if sampler == 'dynamic_temperature' and 'dynatemp_low' in loaders_samplers[shared.args.loader]: + return True + else: + return sampler in loaders_samplers[shared.args.loader] + + +def remove_defaults(state): defaults = default_preset() data = {k: state[k] for k in presets_params()} - # Remove entries that are identical to the defaults. - # sampler_priority is always saved because it is experimental - # and the default order may change. - for k in list(data.keys()): - if data[k] == defaults[k] and k != 'sampler_priority': + if data[k] == defaults[k]: del data[k] + return data + + +def generate_preset_yaml(state): + data = remove_defaults(state) return yaml.dump(data, sort_keys=False)