Improved random preset generation

This commit is contained in:
oobabooga 2024-02-06 08:51:34 -08:00
parent 4e34ae0587
commit c55b8ce932

View File

@ -1,4 +1,5 @@
import functools
import pprint
import random
from pathlib import Path
@ -90,7 +91,25 @@ def random_preset(state):
'eta_cutoff': [3, 6, 9, 12, 15, 18],
},
'flatten_distribution': {
'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0],
'temperature': [0.1, 0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 5.0],
'dynamic_temperature': [
[0.1, 1],
[0.1, 1.5],
[0.1, 2],
[0.1, 5],
[0.5, 1],
[0.5, 1.5],
[0.5, 2],
[0.5, 5],
[0.8, 1],
[0.8, 1.5],
[0.8, 2],
[0.8, 5],
[1, 1.5],
[1, 2],
[1, 5]
],
'smoothing_factor': [0.2, 0.3, 0.6, 1.2]
},
'repetition': {
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
@ -106,26 +125,42 @@ def random_preset(state):
for cat in params_and_values:
choices = list(params_and_values[cat].keys())
if shared.args.loader is not None:
choices = [x for x in choices if x in loaders_samplers[shared.args.loader]]
choices = [x for x in choices if loader_contains(sampler)]
if len(choices) > 0:
choice = random.choice(choices)
generate_params[choice] = random.choice(params_and_values[cat][choice])
value = random.choice(params_and_values[cat][choice])
if choice == 'dynamic_temperature':
generate_params['dynamic_temperature'] = True
generate_params['dynatemp_low'] = value[0]
generate_params['dynatemp_high'] = value[1]
else:
generate_params[choice] = value
state.update(generate_params)
logger.info("GENERATED_PRESET=")
pprint.PrettyPrinter(indent=4, width=1, sort_dicts=False).pprint(remove_defaults(state))
return state, *[generate_params[k] for k in presets_params()]
def generate_preset_yaml(state):
def loader_contains(sampler):
if sampler == 'dynamic_temperature' and 'dynatemp_low' in loaders_samplers[shared.args.loader]:
return True
else:
return sampler in loaders_samplers[shared.args.loader]
def remove_defaults(state):
defaults = default_preset()
data = {k: state[k] for k in presets_params()}
# Remove entries that are identical to the defaults.
# sampler_priority is always saved because it is experimental
# and the default order may change.
for k in list(data.keys()):
if data[k] == defaults[k] and k != 'sampler_priority':
if data[k] == defaults[k]:
del data[k]
return data
def generate_preset_yaml(state):
data = remove_defaults(state)
return yaml.dump(data, sort_keys=False)