llama.cpp: read instruction template from GGUF metadata (#4975)

This commit is contained in:
oobabooga 2023-12-18 01:51:58 -03:00 committed by GitHub
parent 3f3cd4fbe4
commit f0d6ead877
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 3 deletions

View File

@ -82,8 +82,9 @@ def load_metadata(fname):
if value_type == GGUFValueType.ARRAY: if value_type == GGUFValueType.ARRAY:
ltype = GGUFValueType(struct.unpack("<I", file.read(4))[0]) ltype = GGUFValueType(struct.unpack("<I", file.read(4))[0])
length = struct.unpack("<Q", file.read(8))[0] length = struct.unpack("<Q", file.read(8))[0]
for j in range(length):
_ = get_single(ltype, file) arr = [get_single(ltype, file) for _ in range(length)]
metadata[key.decode()] = arr
else: else:
value = get_single(value_type, file) value = get_single(value_type, file)
metadata[key.decode()] = value metadata[key.decode()] = value

View File

@ -64,6 +64,16 @@ def get_model_metadata(model):
model_settings['compress_pos_emb'] = metadata['llama.rope.scale_linear'] model_settings['compress_pos_emb'] = metadata['llama.rope.scale_linear']
if 'llama.rope.freq_base' in metadata: if 'llama.rope.freq_base' in metadata:
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base'] model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
if 'tokenizer.chat_template' in metadata:
template = metadata['tokenizer.chat_template']
eos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']]
bos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.bos_token_id']]
template = template.replace('eos_token', "'{}'".format(eos_token))
template = template.replace('bos_token', "'{}'".format(bos_token))
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
model_settings['instruction_template'] = 'Custom (obtained from model metadata)'
model_settings['instruction_template_str'] = template
else: else:
# Transformers metadata # Transformers metadata
@ -114,7 +124,6 @@ def get_model_metadata(model):
template = template.replace(k, "'{}'".format(value)) template = template.replace(k, "'{}'".format(value))
template = re.sub(r'raise_exception\([^)]*\)', "''", template) template = re.sub(r'raise_exception\([^)]*\)', "''", template)
model_settings['instruction_template'] = 'Custom (obtained from model metadata)' model_settings['instruction_template'] = 'Custom (obtained from model metadata)'
model_settings['instruction_template_str'] = template model_settings['instruction_template_str'] = template