Fix loading command-r context length metadata

This commit is contained in:
oobabooga 2024-04-10 21:36:32 -07:00
parent 3ae61c0338
commit 17c4319e2d

View File

@@ -56,12 +56,13 @@ def get_model_metadata(model):
         model_file = list(path.glob('*.gguf'))[0]
         metadata = metadata_gguf.load_metadata(model_file)
-        if 'llama.context_length' in metadata:
-            model_settings['n_ctx'] = metadata['llama.context_length']
-        if 'llama.rope.scale_linear' in metadata:
-            model_settings['compress_pos_emb'] = metadata['llama.rope.scale_linear']
-        if 'llama.rope.freq_base' in metadata:
-            model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
+        for k in metadata:
+            if k.endswith('context_length'):
+                model_settings['n_ctx'] = metadata[k]
+            elif k.endswith('rope.freq_base'):
+                model_settings['rope_freq_base'] = metadata[k]
+            elif k.endswith('rope.scale_linear'):
+                model_settings['compress_pos_emb'] = metadata[k]
         if 'tokenizer.chat_template' in metadata:
             template = metadata['tokenizer.chat_template']
             eos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']]
@@ -77,7 +78,7 @@ def get_model_metadata(model):
         # Transformers metadata
         if hf_metadata is not None:
             metadata = json.loads(open(path, 'r', encoding='utf-8').read())
-            for k in ['max_position_embeddings', 'max_seq_len']:
+            for k in ['max_position_embeddings', 'model_max_length', 'max_seq_len']:
                 if k in metadata:
                     model_settings['truncation_length'] = metadata[k]
                     model_settings['max_seq_len'] = metadata[k]