mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Improved random preset generation
This commit is contained in:
parent
4e34ae0587
commit
c55b8ce932
@ -1,4 +1,5 @@
|
|||||||
import functools
|
import functools
|
||||||
|
import pprint
|
||||||
import random
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -90,7 +91,25 @@ def random_preset(state):
|
|||||||
'eta_cutoff': [3, 6, 9, 12, 15, 18],
|
'eta_cutoff': [3, 6, 9, 12, 15, 18],
|
||||||
},
|
},
|
||||||
'flatten_distribution': {
|
'flatten_distribution': {
|
||||||
'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0],
|
'temperature': [0.1, 0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 5.0],
|
||||||
|
'dynamic_temperature': [
|
||||||
|
[0.1, 1],
|
||||||
|
[0.1, 1.5],
|
||||||
|
[0.1, 2],
|
||||||
|
[0.1, 5],
|
||||||
|
[0.5, 1],
|
||||||
|
[0.5, 1.5],
|
||||||
|
[0.5, 2],
|
||||||
|
[0.5, 5],
|
||||||
|
[0.8, 1],
|
||||||
|
[0.8, 1.5],
|
||||||
|
[0.8, 2],
|
||||||
|
[0.8, 5],
|
||||||
|
[1, 1.5],
|
||||||
|
[1, 2],
|
||||||
|
[1, 5]
|
||||||
|
],
|
||||||
|
'smoothing_factor': [0.2, 0.3, 0.6, 1.2]
|
||||||
},
|
},
|
||||||
'repetition': {
|
'repetition': {
|
||||||
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
|
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
|
||||||
@ -106,26 +125,42 @@ def random_preset(state):
|
|||||||
for cat in params_and_values:
|
for cat in params_and_values:
|
||||||
choices = list(params_and_values[cat].keys())
|
choices = list(params_and_values[cat].keys())
|
||||||
if shared.args.loader is not None:
|
if shared.args.loader is not None:
|
||||||
choices = [x for x in choices if x in loaders_samplers[shared.args.loader]]
|
choices = [x for x in choices if loader_contains(sampler)]
|
||||||
|
|
||||||
if len(choices) > 0:
|
if len(choices) > 0:
|
||||||
choice = random.choice(choices)
|
choice = random.choice(choices)
|
||||||
generate_params[choice] = random.choice(params_and_values[cat][choice])
|
value = random.choice(params_and_values[cat][choice])
|
||||||
|
if choice == 'dynamic_temperature':
|
||||||
|
generate_params['dynamic_temperature'] = True
|
||||||
|
generate_params['dynatemp_low'] = value[0]
|
||||||
|
generate_params['dynatemp_high'] = value[1]
|
||||||
|
else:
|
||||||
|
generate_params[choice] = value
|
||||||
|
|
||||||
state.update(generate_params)
|
state.update(generate_params)
|
||||||
|
logger.info("GENERATED_PRESET=")
|
||||||
|
pprint.PrettyPrinter(indent=4, width=1, sort_dicts=False).pprint(remove_defaults(state))
|
||||||
return state, *[generate_params[k] for k in presets_params()]
|
return state, *[generate_params[k] for k in presets_params()]
|
||||||
|
|
||||||
|
|
||||||
def generate_preset_yaml(state):
|
def loader_contains(sampler):
|
||||||
|
if sampler == 'dynamic_temperature' and 'dynatemp_low' in loaders_samplers[shared.args.loader]:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return sampler in loaders_samplers[shared.args.loader]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_defaults(state):
|
||||||
defaults = default_preset()
|
defaults = default_preset()
|
||||||
data = {k: state[k] for k in presets_params()}
|
data = {k: state[k] for k in presets_params()}
|
||||||
|
|
||||||
# Remove entries that are identical to the defaults.
|
|
||||||
# sampler_priority is always saved because it is experimental
|
|
||||||
# and the default order may change.
|
|
||||||
|
|
||||||
for k in list(data.keys()):
|
for k in list(data.keys()):
|
||||||
if data[k] == defaults[k] and k != 'sampler_priority':
|
if data[k] == defaults[k]:
|
||||||
del data[k]
|
del data[k]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def generate_preset_yaml(state):
|
||||||
|
data = remove_defaults(state)
|
||||||
return yaml.dump(data, sort_keys=False)
|
return yaml.dump(data, sort_keys=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user