Merge remote-tracking branch 'refs/remotes/origin/dev' into dev

This commit is contained in:
oobabooga 2023-12-17 21:05:10 -08:00
commit da1c8d77ea
2 changed files with 13 additions and 3 deletions

View File

@ -82,8 +82,9 @@ def load_metadata(fname):
if value_type == GGUFValueType.ARRAY:
ltype = GGUFValueType(struct.unpack("<I", file.read(4))[0])
length = struct.unpack("<Q", file.read(8))[0]
for j in range(length):
_ = get_single(ltype, file)
arr = [get_single(ltype, file) for _ in range(length)]
metadata[key.decode()] = arr
else:
value = get_single(value_type, file)
metadata[key.decode()] = value

View File

@ -64,6 +64,16 @@ def get_model_metadata(model):
model_settings['compress_pos_emb'] = metadata['llama.rope.scale_linear']
if 'llama.rope.freq_base' in metadata:
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
if 'tokenizer.chat_template' in metadata:
template = metadata['tokenizer.chat_template']
eos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']]
bos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.bos_token_id']]
template = template.replace('eos_token', "'{}'".format(eos_token))
template = template.replace('bos_token', "'{}'".format(bos_token))
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
model_settings['instruction_template'] = 'Custom (obtained from model metadata)'
model_settings['instruction_template_str'] = template
else:
# Transformers metadata
@ -114,7 +124,6 @@ def get_model_metadata(model):
template = template.replace(k, "'{}'".format(value))
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
model_settings['instruction_template'] = 'Custom (obtained from model metadata)'
model_settings['instruction_template_str'] = template