Sampler priority: better logging, always save to presets

This commit is contained in:
oobabooga 2024-02-06 06:49:22 -08:00
parent acfbe6b3b3
commit 775902c1f2
2 changed files with 8 additions and 6 deletions

View File

@ -120,9 +120,12 @@ def generate_preset_yaml(state):
defaults = default_preset() defaults = default_preset()
data = {k: state[k] for k in presets_params()} data = {k: state[k] for k in presets_params()}
# Remove entries that are identical to the defaults # Remove entries that are identical to the defaults.
# sampler_priority is always saved because it is experimental
# and the default order may change.
for k in list(data.keys()): for k in list(data.keys()):
if data[k] == defaults[k]: if data[k] == defaults[k] and k != 'sampler_priority':
del data[k] del data[k]
return yaml.dump(data, sort_keys=False) return yaml.dump(data, sort_keys=False)

View File

@ -428,16 +428,15 @@ def get_logits_warper_patch(self, generation_config):
# Sort the list using the custom key function # Sort the list using the custom key function
warpers = sorted(warpers, key=custom_sort_key) warpers = sorted(warpers, key=custom_sort_key)
if shared.args.verbose:
logger.info("WARPERS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers])
if normalize is not None: if normalize is not None:
warpers.append(normalize) warpers.append(normalize)
warpers.append(SpyLogitsWarper()) warpers.append(SpyLogitsWarper())
warpers = LogitsProcessorList(warpers) warpers = LogitsProcessorList(warpers)
if shared.args.verbose:
logger.info("WARPERS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers])
return warpers return warpers