Fix loading command-r context length metadata

This commit is contained in:
oobabooga 2024-04-10 21:36:32 -07:00
parent 3ae61c0338
commit 17c4319e2d

View File

@ -56,12 +56,13 @@ def get_model_metadata(model):
model_file = list(path.glob('*.gguf'))[0]
metadata = metadata_gguf.load_metadata(model_file)
if 'llama.context_length' in metadata:
model_settings['n_ctx'] = metadata['llama.context_length']
if 'llama.rope.scale_linear' in metadata:
model_settings['compress_pos_emb'] = metadata['llama.rope.scale_linear']
if 'llama.rope.freq_base' in metadata:
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
for k in metadata:
if k.endswith('context_length'):
model_settings['n_ctx'] = metadata[k]
elif k.endswith('rope.freq_base'):
model_settings['rope_freq_base'] = metadata[k]
elif k.endswith('rope.scale_linear'):
model_settings['compress_pos_emb'] = metadata[k]
if 'tokenizer.chat_template' in metadata:
template = metadata['tokenizer.chat_template']
eos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']]
@ -77,7 +78,7 @@ def get_model_metadata(model):
# Transformers metadata
if hf_metadata is not None:
metadata = json.loads(open(path, 'r', encoding='utf-8').read())
for k in ['max_position_embeddings', 'max_seq_len']:
for k in ['max_position_embeddings', 'model_max_length', 'max_seq_len']:
if k in metadata:
model_settings['truncation_length'] = metadata[k]
model_settings['max_seq_len'] = metadata[k]