2023-06-14 01:34:35 +02:00
|
|
|
import functools
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
|
|
|
|
|
2023-08-01 04:13:29 +02:00
|
|
|
def default_preset():
    """Return a fresh dict holding the baseline value of every sampling parameter.

    A new dict is built on every call, so callers may mutate the result freely.
    Key order matters: ``presets_params()`` derives the parameter list from it,
    and ``load_preset_for_ui`` returns values in that same order.
    """
    return dict(
        do_sample=True,
        temperature=1,
        top_p=1,
        top_k=0,
        typical_p=1,
        epsilon_cutoff=0,
        eta_cutoff=0,
        tfs=1,
        top_a=0,
        repetition_penalty=1,
        repetition_penalty_range=0,
        encoder_repetition_penalty=1,
        no_repeat_ngram_size=0,
        min_length=0,
        guidance_scale=1,
        mirostat_mode=0,
        mirostat_tau=5.0,
        mirostat_eta=0.1,
        penalty_alpha=0,
        num_beams=1,
        length_penalty=1,
        early_stopping=False,
    )
|
|
|
|
|
2023-08-01 04:13:29 +02:00
|
|
|
|
2023-08-06 22:22:48 +02:00
|
|
|
def presets_params():
    """Return the names of all preset parameters, in ``default_preset()`` key order."""
    return list(default_preset())
|
|
|
|
|
|
|
|
|
2023-08-01 04:13:29 +02:00
|
|
|
def load_preset(name):
    """Build the generation parameters for the preset called *name*.

    Starts from ``default_preset()`` and overlays every key found in
    ``presets/<name>.yaml``. The values ``'None'``, ``None`` and ``''`` mean
    "no preset" and yield the defaults. In every case the returned temperature
    is capped at 1.99.

    Args:
        name: Preset file stem (without the ``.yaml`` suffix), or one of the
            "no preset" sentinels above.

    Returns:
        A new dict of generation parameters.

    Raises:
        FileNotFoundError: if ``presets/<name>.yaml`` does not exist.
        yaml.YAMLError: if the file is not valid YAML.
    """
    generate_params = default_preset()
    if name not in ['None', None, '']:
        # Read as UTF-8 explicitly so loading does not depend on the
        # platform's default locale encoding.
        with open(Path(f'presets/{name}.yaml'), 'r', encoding='utf-8') as infile:
            preset = yaml.safe_load(infile)

        # An empty (or comment-only) YAML file loads as None; treat it as
        # "no overrides" instead of failing when iterating over it.
        if preset:
            generate_params.update(preset)

    # NOTE(review): the reason for the 1.99 ceiling is not visible here —
    # presumably an upper bound imposed by the UI/backend; confirm.
    generate_params['temperature'] = min(1.99, generate_params['temperature'])
    return generate_params
|
|
|
|
|
|
|
|
|
|
|
|
@functools.cache
def _load_preset_cached(name):
    """Load preset *name* once per distinct name and keep the parsed result."""
    return load_preset(name)


def load_preset_memoized(name):
    """Return the parameters of preset *name*, reading its file only on first use.

    The parsed preset is cached per *name* for the life of the process, so later
    edits to the YAML file are not seen until the cache is cleared. Each call
    returns a fresh shallow copy: handing out the cached dict itself would let
    one caller's mutation silently change the preset every later caller gets.
    """
    return dict(_load_preset_cached(name))


# Keep the cache-management helpers that @functools.cache used to attach
# directly to this function, so existing callers of them keep working.
load_preset_memoized.cache_clear = _load_preset_cached.cache_clear
load_preset_memoized.cache_info = _load_preset_cached.cache_info
|
|
|
|
|
|
|
|
|
|
|
|
def load_preset_for_ui(name, state):
    """Load preset *name*, merge it into *state*, and return the UI outputs.

    *state* is updated in place with every loaded parameter. The result is a
    tuple: the same *state* object followed by one value per preset parameter,
    in ``presets_params()`` order (presumably feeding the matching UI widgets).
    """
    loaded = load_preset(name)
    state.update(loaded)
    values = [loaded[param] for param in presets_params()]
    return (state, *values)
|
2023-06-14 01:34:35 +02:00
|
|
|
|
|
|
|
|
|
|
|
def generate_preset_yaml(state):
    """Serialise the preset parameters held in *state* to a YAML string.

    Only parameters whose value differs from ``default_preset()`` are emitted,
    in ``presets_params()`` order (``sort_keys=False``), so the output lists
    just the overrides.

    Raises:
        KeyError: if *state* lacks any preset parameter.
    """
    baseline = default_preset()
    overrides = {}
    for param in presets_params():
        value = state[param]
        if value != baseline[param]:
            overrides[param] = value

    return yaml.dump(overrides, sort_keys=False)
|