From 775902c1f22333facf3d05431ecbe5377193f25e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 6 Feb 2024 06:49:22 -0800 Subject: [PATCH] Sampler priority: better logging, always save to presets --- modules/presets.py | 7 +++++-- modules/sampler_hijack.py | 7 +++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/modules/presets.py b/modules/presets.py index 2a4a4dde..2686355d 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -120,9 +120,12 @@ def generate_preset_yaml(state): defaults = default_preset() data = {k: state[k] for k in presets_params()} - # Remove entries that are identical to the defaults + # Remove entries that are identical to the defaults. + # sampler_priority is always saved because it is experimental + # and the default order may change. + for k in list(data.keys()): - if data[k] == defaults[k]: + if data[k] == defaults[k] and k != 'sampler_priority': del data[k] return yaml.dump(data, sort_keys=False) diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 9701b034..6f8de416 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -428,16 +428,15 @@ def get_logits_warper_patch(self, generation_config): # Sort the list using the custom key function warpers = sorted(warpers, key=custom_sort_key) + if shared.args.verbose: + logger.info("WARPERS=") + pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers]) if normalize is not None: warpers.append(normalize) warpers.append(SpyLogitsWarper()) warpers = LogitsProcessorList(warpers) - if shared.args.verbose: - logger.info("WARPERS=") - pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers]) - return warpers