diff --git a/modules/presets.py b/modules/presets.py index 0af2928a..072b15fd 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -4,8 +4,8 @@ from pathlib import Path import yaml -def load_preset(name): - generate_params = { +def default_preset(): + return { 'do_sample': True, 'temperature': 1, 'top_p': 1, @@ -29,6 +29,9 @@ def load_preset(name): 'mirostat_eta': 0.1, } + +def load_preset(name): + generate_params = default_preset() if name not in ['None', None, '']: with open(Path(f'presets/{name}.yaml'), 'r') as infile: preset = yaml.safe_load(infile) @@ -52,5 +55,12 @@ def load_preset_for_ui(name, state): def generate_preset_yaml(state): + defaults = default_preset() data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']} + + # Remove entries that are identical to the defaults + for k in list(data.keys()): + if data[k] == defaults[k]: + del data[k] + return yaml.dump(data, sort_keys=False)