diff --git a/modules/models_settings.py b/modules/models_settings.py index e3f86c31..ca87b44b 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -52,7 +52,14 @@ def get_model_metadata(model): if 'llama.rope.scale_linear' in metadata: model_settings['compress_pos_emb'] = metadata['llama.rope.scale_linear'] if 'llama.rope.freq_base' in metadata: - model_settings['rope_freq_base'] = metadata['llama.rope.freq_base'] + model_settings['rope_freq_base'] = metadata['llama.rope.freq_base'] + + # Apply user settings from models/config-user.yaml + settings = shared.user_config + for pat in settings: + if re.match(pat.lower(), model.lower()): + for k in settings[pat]: + model_settings[k] = settings[pat][k] return model_settings @@ -155,17 +162,14 @@ def save_model_settings(model, state): user_config = {} model_regex = model + '$' # For exact matches - for _dict in [user_config, shared.model_config]: - if model_regex not in _dict: - _dict[model_regex] = {} - if model_regex not in user_config: user_config[model_regex] = {} for k in ui.list_model_elements(): if k == 'loader' or k in loaders.loaders_and_params[state['loader']]: user_config[model_regex][k] = state[k] - shared.model_config[model_regex][k] = state[k] + + shared.user_config = user_config output = yaml.dump(user_config, sort_keys=False) with open(p, 'w') as f: diff --git a/modules/shared.py b/modules/shared.py index 06aafc8d..2555eca4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -259,10 +259,8 @@ with Path(f'{args.model_dir}/config.yaml') as p: with Path(f'{args.model_dir}/config-user.yaml') as p: if p.exists(): user_config = yaml.safe_load(open(p, 'r').read()) - for k in user_config: - if k in model_config: - model_config[k].update(user_config[k]) - else: - model_config[k] = user_config[k] + else: + user_config = {} model_config = OrderedDict(model_config) +user_config = OrderedDict(user_config)