Prevent extra keys from being saved to settings.yaml

This commit is contained in:
oobabooga 2023-09-11 20:13:10 -07:00
parent dae428a967
commit df123a20fc
4 changed files with 4 additions and 4 deletions

View File

@ -130,7 +130,7 @@ class Handler(BaseHTTPRequestHandler):
unload_model() unload_model()
model_settings = get_model_metadata(shared.model_name) model_settings = get_model_metadata(shared.model_name)
shared.settings.update(model_settings) shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
update_model_parameters(model_settings, initial=True) update_model_parameters(model_settings, initial=True)
if shared.settings['mode'] != 'instruct': if shared.settings['mode'] != 'instruct':

View File

@ -32,7 +32,7 @@ def load_model(model_name: str) -> dict:
unload_model() unload_model()
model_settings = get_model_metadata(shared.model_name) model_settings = get_model_metadata(shared.model_name)
shared.settings.update(model_settings) shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
update_model_parameters(model_settings, initial=True) update_model_parameters(model_settings, initial=True)
if shared.settings['mode'] != 'instruct': if shared.settings['mode'] != 'instruct':

View File

@ -67,7 +67,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
try: try:
yield cumulative_log + f"Loading {model}...\n\n" yield cumulative_log + f"Loading {model}...\n\n"
model_settings = get_model_metadata(model) model_settings = get_model_metadata(model)
shared.settings.update(model_settings) # hijacking the interface defaults shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults
update_model_parameters(model_settings) # hijacking the command-line arguments update_model_parameters(model_settings) # hijacking the command-line arguments
shared.model_name = model shared.model_name = model
unload_model() unload_model()

View File

@ -212,7 +212,7 @@ if __name__ == "__main__":
model_name = shared.model_name model_name = shared.model_name
model_settings = get_model_metadata(model_name) model_settings = get_model_metadata(model_name)
shared.settings.update(model_settings) # hijacking the interface defaults shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults
update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments
# Load the model # Load the model