#!/usr/bin/env python
'''
  Fetches the Jinja chat template of a HuggingFace model.
  If a model has multiple chat templates, you can specify the variant name.

  Syntax:
    ./scripts/get_hf_chat_template.py model_id [variant]

  Examples:
    ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
    ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
    ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
'''
import json
import re
import sys
def _read_tokenizer_config(model_id):
    """Return the raw text of ``tokenizer_config.json`` for *model_id*.

    The file is fetched through ``huggingface_hub`` when that package can be
    imported; otherwise it is downloaded anonymously with ``requests``.

    Raises:
        ValueError: In the ``requests`` fallback, if *model_id* is not of the
            form ``owner/name``.
        Exception: In the ``requests`` fallback, if the server answers 401
            (gated model).
    """
    try:
        # Use huggingface_hub library if available.
        # Allows access to gated models if the user has access and ran `huggingface-cli login`.
        from huggingface_hub import hf_hub_download
        # The config is JSON, hence UTF-8; do not depend on the locale's default encoding.
        with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json"), encoding="utf-8") as f:
            return f.read()
    except ImportError:
        import requests

        # The ID is interpolated into a URL below, so it is validated with an
        # explicit raise (an `assert` disappears under `python -O`) and with
        # fullmatch (`$` would also accept a trailing newline).
        if not re.fullmatch(r"[\w.-]+/[\w.-]+", model_id):
            raise ValueError(f"Invalid model ID: {model_id}")
        response = requests.get(
            f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json",
            timeout=30,  # requests waits forever by default
        )
        if response.status_code == 401:
            raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
        response.raise_for_status()
        return response.text


def _parse_tokenizer_config(config_str):
    """Parse the tokenizer config JSON, repairing one known malformed variant."""
    try:
        return json.loads(config_str)
    except json.JSONDecodeError:
        # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
        # (Remove extra '}' near the end of the file)
        repaired = re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)
        return json.loads(repaired)


def get_hf_chat_template(model_id, variant=None):
    """Return the Jinja chat template of a HuggingFace model.

    Args:
        model_id: Repository ID, e.g. ``NousResearch/Meta-Llama-3-8B-Instruct``.
        variant: Name of the template to pick when the model's config holds a
            list of named templates. When omitted, the ``default`` entry is
            used and a note is written to stderr. Ignored when the config
            holds a single template string.

    Returns:
        The chat template source as a string.

    Raises:
        KeyError: If the tokenizer config has no ``chat_template`` entry.
        Exception: If *variant* is not among the available templates, or if no
            variant was given and there is no ``default`` one.
        json.JSONDecodeError: If the config cannot be parsed even after the
            known repair.
    """
    config = _parse_tokenizer_config(_read_tokenizer_config(model_id))

    chat_template = config['chat_template']
    if isinstance(chat_template, str):
        return chat_template

    # Otherwise the config holds a list of {"name": ..., "template": ...} entries.
    variants = {ct['name']: ct['template'] for ct in chat_template}
    available = ', '.join(f'"{name}"' for name in variants)

    if variant is None:
        if 'default' not in variants:
            raise Exception(f'Please specify a chat template variant (one of {available})')
        variant = 'default'
        sys.stderr.write(f'Note: picked "default" chat template variant (out of {available})\n')
    elif variant not in variants:
        raise Exception(f"Variant {variant} not found in chat template (found {available})")

    return variants[variant]
def main(args):
    """Command-line entry point: write a model's chat template to stdout.

    Args:
        args: Command-line arguments without the program name. ``args[0]`` is
            the model ID and the optional ``args[1]`` is the variant name;
            any further arguments are ignored.

    Raises:
        ValueError: If no model ID is given.
    """
    if not args:
        raise ValueError("Please provide a model ID and an optional variant name")
    model_id = args[0]
    variant = args[1] if len(args) > 1 else None

    template = get_hf_chat_template(model_id, variant)
    # Written verbatim, with no trailing newline added.
    sys.stdout.write(template)
# Run as a script: forward the command-line arguments (minus the program name).
if __name__ == '__main__':
    main(sys.argv[1:])