mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-26 14:20:40 +01:00
Improved random preset generation
This commit is contained in:
parent
4e34ae0587
commit
c55b8ce932
@ -1,4 +1,5 @@
|
||||
import functools
|
||||
import pprint
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
@ -90,7 +91,25 @@ def random_preset(state):
|
||||
'eta_cutoff': [3, 6, 9, 12, 15, 18],
|
||||
},
|
||||
'flatten_distribution': {
|
||||
'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0],
|
||||
'temperature': [0.1, 0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 5.0],
|
||||
'dynamic_temperature': [
|
||||
[0.1, 1],
|
||||
[0.1, 1.5],
|
||||
[0.1, 2],
|
||||
[0.1, 5],
|
||||
[0.5, 1],
|
||||
[0.5, 1.5],
|
||||
[0.5, 2],
|
||||
[0.5, 5],
|
||||
[0.8, 1],
|
||||
[0.8, 1.5],
|
||||
[0.8, 2],
|
||||
[0.8, 5],
|
||||
[1, 1.5],
|
||||
[1, 2],
|
||||
[1, 5]
|
||||
],
|
||||
'smoothing_factor': [0.2, 0.3, 0.6, 1.2]
|
||||
},
|
||||
'repetition': {
|
||||
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
|
||||
@ -106,26 +125,42 @@ def random_preset(state):
|
||||
for cat in params_and_values:
|
||||
choices = list(params_and_values[cat].keys())
|
||||
if shared.args.loader is not None:
|
||||
choices = [x for x in choices if x in loaders_samplers[shared.args.loader]]
|
||||
choices = [x for x in choices if loader_contains(sampler)]
|
||||
|
||||
if len(choices) > 0:
|
||||
choice = random.choice(choices)
|
||||
generate_params[choice] = random.choice(params_and_values[cat][choice])
|
||||
value = random.choice(params_and_values[cat][choice])
|
||||
if choice == 'dynamic_temperature':
|
||||
generate_params['dynamic_temperature'] = True
|
||||
generate_params['dynatemp_low'] = value[0]
|
||||
generate_params['dynatemp_high'] = value[1]
|
||||
else:
|
||||
generate_params[choice] = value
|
||||
|
||||
state.update(generate_params)
|
||||
logger.info("GENERATED_PRESET=")
|
||||
pprint.PrettyPrinter(indent=4, width=1, sort_dicts=False).pprint(remove_defaults(state))
|
||||
return state, *[generate_params[k] for k in presets_params()]
|
||||
|
||||
|
||||
def generate_preset_yaml(state):
|
||||
def loader_contains(sampler):
|
||||
if sampler == 'dynamic_temperature' and 'dynatemp_low' in loaders_samplers[shared.args.loader]:
|
||||
return True
|
||||
else:
|
||||
return sampler in loaders_samplers[shared.args.loader]
|
||||
|
||||
|
||||
def remove_defaults(state):
|
||||
defaults = default_preset()
|
||||
data = {k: state[k] for k in presets_params()}
|
||||
|
||||
# Remove entries that are identical to the defaults.
|
||||
# sampler_priority is always saved because it is experimental
|
||||
# and the default order may change.
|
||||
|
||||
for k in list(data.keys()):
|
||||
if data[k] == defaults[k] and k != 'sampler_priority':
|
||||
if data[k] == defaults[k]:
|
||||
del data[k]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def generate_preset_yaml(state):
|
||||
data = remove_defaults(state)
|
||||
return yaml.dump(data, sort_keys=False)
|
||||
|
Loading…
Reference in New Issue
Block a user