From df123a20fccda29439d70c5590de2e33937acbed Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 11 Sep 2023 20:13:10 -0700 Subject: [PATCH] Prevent extra keys from being saved to settings.yaml --- extensions/api/blocking_api.py | 2 +- extensions/openai/models.py | 2 +- modules/evaluate.py | 2 +- server.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py index 8a643c15..a91fd515 100644 --- a/extensions/api/blocking_api.py +++ b/extensions/api/blocking_api.py @@ -130,7 +130,7 @@ class Handler(BaseHTTPRequestHandler): unload_model() model_settings = get_model_metadata(shared.model_name) - shared.settings.update(model_settings) + shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) update_model_parameters(model_settings, initial=True) if shared.settings['mode'] != 'instruct': diff --git a/extensions/openai/models.py b/extensions/openai/models.py index 25263e4b..e6715a81 100644 --- a/extensions/openai/models.py +++ b/extensions/openai/models.py @@ -32,7 +32,7 @@ def load_model(model_name: str) -> dict: unload_model() model_settings = get_model_metadata(shared.model_name) - shared.settings.update(model_settings) + shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) update_model_parameters(model_settings, initial=True) if shared.settings['mode'] != 'instruct': diff --git a/modules/evaluate.py b/modules/evaluate.py index 609419c8..8044e203 100644 --- a/modules/evaluate.py +++ b/modules/evaluate.py @@ -67,7 +67,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length): try: yield cumulative_log + f"Loading {model}...\n\n" model_settings = get_model_metadata(model) - shared.settings.update(model_settings) # hijacking the interface defaults + shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults update_model_parameters(model_settings) # hijacking the command-line arguments shared.model_name = model unload_model() diff --git a/server.py b/server.py index 597b47be..ac3bcec1 100644 --- a/server.py +++ b/server.py @@ -212,7 +212,7 @@ if __name__ == "__main__": model_name = shared.model_name model_settings = get_model_metadata(model_name) - shared.settings.update(model_settings) # hijacking the interface defaults + shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments # Load the model