diff --git a/modules/models_settings.py b/modules/models_settings.py
index 85689b8b..b7a7d332 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -56,12 +56,13 @@ def get_model_metadata(model):
             model_file = list(path.glob('*.gguf'))[0]
 
         metadata = metadata_gguf.load_metadata(model_file)
-        if 'llama.context_length' in metadata:
-            model_settings['n_ctx'] = metadata['llama.context_length']
-        if 'llama.rope.scale_linear' in metadata:
-            model_settings['compress_pos_emb'] = metadata['llama.rope.scale_linear']
-        if 'llama.rope.freq_base' in metadata:
-            model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
+        for k in metadata:
+            if k.endswith('context_length'):
+                model_settings['n_ctx'] = metadata[k]
+            elif k.endswith('rope.freq_base'):
+                model_settings['rope_freq_base'] = metadata[k]
+            elif k.endswith('rope.scale_linear'):
+                model_settings['compress_pos_emb'] = metadata[k]
         if 'tokenizer.chat_template' in metadata:
             template = metadata['tokenizer.chat_template']
             eos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']]
@@ -77,7 +78,7 @@ def get_model_metadata(model):
         # Transformers metadata
         if hf_metadata is not None:
             metadata = json.loads(open(path, 'r', encoding='utf-8').read())
-            for k in ['max_position_embeddings', 'max_seq_len']:
+            for k in ['max_position_embeddings', 'model_max_length', 'max_seq_len']:
                 if k in metadata:
                     model_settings['truncation_length'] = metadata[k]
                     model_settings['max_seq_len'] = metadata[k]
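
Context for the first hunk: GGUF metadata keys are namespaced by model architecture (`llama.context_length`, `falcon.context_length`, and so on), so the old hard-coded `llama.*` lookups silently skipped non-llama models. Below is a minimal, self-contained sketch of the suffix-matching idea; the sample `metadata` dict is made up for illustration, while the real one comes from `metadata_gguf.load_metadata(model_file)`:

```python
# Minimal sketch of the suffix-matching logic from the patch.
# The keys below are illustrative; real GGUF metadata uses
# architecture-prefixed key names.
metadata = {
    'falcon.context_length': 2048,
    'falcon.rope.freq_base': 10000.0,
    'tokenizer.ggml.eos_token_id': 11,
}

model_settings = {}
for k in metadata:
    if k.endswith('context_length'):
        model_settings['n_ctx'] = metadata[k]
    elif k.endswith('rope.freq_base'):
        model_settings['rope_freq_base'] = metadata[k]
    elif k.endswith('rope.scale_linear'):
        model_settings['compress_pos_emb'] = metadata[k]

print(model_settings)  # {'n_ctx': 2048, 'rope_freq_base': 10000.0}
```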
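
The second hunk addresses the analogous problem on the Transformers side: the maximum sequence length can appear in a model's JSON config under different names depending on the model, so the patch probes a list of known keys. A hedged sketch of that fallback, with a made-up config payload standing in for the real file:

```python
import json

# Made-up config payload; the patch reads the real file with
# json.loads(open(path, 'r', encoding='utf-8').read()).
metadata = json.loads('{"model_max_length": 4096}')

model_settings = {}
# Probe known key names in order; a later match overwrites an earlier one.
for k in ['max_position_embeddings', 'model_max_length', 'max_seq_len']:
    if k in metadata:
        model_settings['truncation_length'] = metadata[k]
        model_settings['max_seq_len'] = metadata[k]

print(model_settings)  # {'truncation_length': 4096, 'max_seq_len': 4096}
```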