mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Save extension fields to settings.yaml on "Save UI defaults"
This commit is contained in:
parent
9e86bea8e9
commit
248742df1c
@ -6,6 +6,7 @@ import torch
|
|||||||
import yaml
|
import yaml
|
||||||
from transformers import is_torch_xpu_available
|
from transformers import is_torch_xpu_available
|
||||||
|
|
||||||
|
import extensions
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
with open(Path(__file__).resolve().parent / '../css/NotoSans/stylesheet.css', 'r') as f:
|
with open(Path(__file__).resolve().parent / '../css/NotoSans/stylesheet.css', 'r') as f:
|
||||||
@ -204,7 +205,7 @@ def apply_interface_values(state, use_persistent=False):
|
|||||||
return [state[k] if k in state else gr.update() for k in elements]
|
return [state[k] if k in state else gr.update() for k in elements]
|
||||||
|
|
||||||
|
|
||||||
def save_settings(state, preset, extensions, show_controls):
|
def save_settings(state, preset, extensions_list, show_controls):
|
||||||
output = copy.deepcopy(shared.settings)
|
output = copy.deepcopy(shared.settings)
|
||||||
exclude = ['name2', 'greeting', 'context', 'turn_template']
|
exclude = ['name2', 'greeting', 'context', 'turn_template']
|
||||||
for k in state:
|
for k in state:
|
||||||
@ -215,10 +216,19 @@ def save_settings(state, preset, extensions, show_controls):
|
|||||||
output['prompt-default'] = state['prompt_menu-default']
|
output['prompt-default'] = state['prompt_menu-default']
|
||||||
output['prompt-notebook'] = state['prompt_menu-notebook']
|
output['prompt-notebook'] = state['prompt_menu-notebook']
|
||||||
output['character'] = state['character_menu']
|
output['character'] = state['character_menu']
|
||||||
output['default_extensions'] = extensions
|
output['default_extensions'] = extensions_list
|
||||||
output['seed'] = int(output['seed'])
|
output['seed'] = int(output['seed'])
|
||||||
output['show_controls'] = show_controls
|
output['show_controls'] = show_controls
|
||||||
|
|
||||||
|
# Save extension values in the UI
|
||||||
|
for extension_name in extensions_list:
|
||||||
|
extension = getattr(extensions, extension_name).script
|
||||||
|
if hasattr(extension, 'params'):
|
||||||
|
params = getattr(extension, 'params')
|
||||||
|
for param in params:
|
||||||
|
_id = f"{extension_name}-{param}"
|
||||||
|
output[_id] = params[param]
|
||||||
|
|
||||||
return yaml.dump(output, sort_keys=False, width=float("inf"))
|
return yaml.dump(output, sort_keys=False, width=float("inf"))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user