2023-01-06 23:57:31 +01:00
'''
2023-06-21 04:36:56 +02:00
Downloads models from Hugging Face to models/username_modelname.
2023-01-06 23:57:31 +01:00
Example:
2023-04-09 22:00:59 +02:00
python download-model.py facebook/opt-1.3b
2023-01-06 23:57:31 +01:00
'''
2023-03-10 04:41:10 +01:00
2023-02-10 19:40:03 +01:00
import argparse
2023-03-10 04:41:10 +01:00
import base64
2023-03-30 01:26:44 +02:00
import datetime
2023-03-31 06:31:47 +02:00
import hashlib
2023-02-24 18:06:42 +01:00
import json
2023-06-01 05:11:21 +02:00
import os
2023-06-21 04:36:56 +02:00
import re
2023-01-20 21:51:56 +01:00
import sys
2023-01-07 20:33:43 +01:00
from pathlib import Path
2024-04-27 17:25:28 +02:00
from time import sleep
2023-02-10 19:40:03 +01:00
import requests
import tqdm
2023-07-05 03:26:30 +02:00
from requests . adapters import HTTPAdapter
2024-04-27 17:25:28 +02:00
from requests . exceptions import ConnectionError , RequestException , Timeout
2023-03-29 03:29:20 +02:00
from tqdm . contrib . concurrent import thread_map
2023-01-20 21:51:56 +01:00
2024-04-11 23:28:10 +02:00
# Hub endpoint used for every API and download URL. HF_ENDPOINT overrides the
# default (e.g. for a mirror). Trailing slashes are stripped because URLs are
# built as f"{base}/..." and a doubled slash can break routing on the server.
base = (os.environ.get("HF_ENDPOINT") or "https://huggingface.co").rstrip("/")
2023-09-16 15:06:13 +02:00
2023-06-01 05:11:21 +02:00
class ModelDownloader:
    """Download Hugging Face model repositories, or single files from them.

    API and download URLs are built from the module-level ``base`` endpoint.
    """

    def __init__(self, max_retries=5):
        # Retry count handed to requests' HTTPAdapter in get_session().
        self.max_retries = max_retries
        # Optional progress callback ``fn(fraction, description)`` used while
        # streaming in get_single_file(). download_model_files() overwrites it;
        # initialising it here lets get_single_file() be called on its own
        # without raising AttributeError.
        self.progress_bar = None

    def get_session(self):
        """Return a new ``requests.Session`` configured with retries and credentials.

        * When ``self.max_retries`` is truthy, an HTTPAdapter with that retry
          count is mounted for the LFS CDN, huggingface.co and the configured
          ``base`` endpoint.
        * HTTP basic auth is taken from ``HF_USER`` / ``HF_PASS`` when both are set.
        * A bearer token comes from ``huggingface_hub.get_token()`` when that
          package is importable, otherwise from the ``HF_TOKEN`` env variable.

        The caller owns the returned session and should close it.
        """
        session = requests.Session()
        if self.max_retries:
            # ``base`` is included so a custom HF_ENDPOINT gets the same retries.
            for prefix in ('https://cdn-lfs.huggingface.co', 'https://huggingface.co', base):
                session.mount(prefix, HTTPAdapter(max_retries=self.max_retries))

        user, password = os.getenv('HF_USER'), os.getenv('HF_PASS')
        if user is not None and password is not None:
            session.auth = (user, password)

        try:
            from huggingface_hub import get_token
            token = get_token()
        except ImportError:
            token = os.getenv("HF_TOKEN")

        if token is not None:
            # NOTE: this replaces the session's default headers instead of
            # extending them; kept as-is to preserve the existing request shape.
            session.headers = {'authorization': f'Bearer {token}'}

        return session

    def sanitize_model_and_branch_names(self, model, branch):
        """Normalise a user-supplied model id and validate the branch name.

        Accepts ``user/model``, ``user/model:branch`` or a URL starting with
        ``base + '/'``. A single trailing slash is dropped. A branch given
        after ``:`` in ``model`` takes precedence over ``branch``.

        Returns:
            ``(model, branch)``; ``branch`` defaults to ``"main"`` when None.

        Raises:
            ValueError: if the branch is empty or contains characters other
                than letters, digits, ``.``, ``_`` or ``-``.
        """
        # endswith() instead of model[-1] so an empty string does not raise IndexError.
        if model.endswith('/'):
            model = model[:-1]

        if model.startswith(base + '/'):
            model = model[len(base) + 1:]

        model_parts = model.split(":")
        model = model_parts[0]
        if len(model_parts) > 1:
            branch = model_parts[1]

        if branch is None:
            return model, "main"

        # fullmatch (unlike match with '$') also rejects a trailing newline.
        if not re.fullmatch(r"[a-zA-Z0-9._-]+", branch):
            raise ValueError(
                "Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
        return model, branch

    def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
        """List the files of ``model`` at ``branch`` that should be downloaded.

        Returns:
            ``(links, sha256, is_lora, is_llamacpp)``: download URLs, a list of
            ``[path, lfs_oid]`` pairs for the selected LFS files, whether the
            repo looks like a LoRA adapter, and whether a specific GGUF file
            was requested (such files are saved directly in the base folder).

        Raises:
            requests.HTTPError: if the tree API answers with a non-2xx status.
        """
        # '' and None both mean "no specific file". Normalise once so the file
        # filter, the Q4_K_M pruning and is_llamacpp all agree.
        specific_file = specific_file or None
        entries = self._list_repo_tree(model, branch)
        links, sha256, is_lora, has_gguf = self._select_files(
            entries, f"{base}/{model}/resolve/{branch}/",
            text_only=text_only, specific_file=specific_file)
        is_llamacpp = has_gguf and specific_file is not None
        return links, sha256, is_lora, is_llamacpp

    def _list_repo_tree(self, model, branch):
        """Return all entries of the repo's top-level tree listing, following pagination."""
        page = f"/api/models/{model}/tree/{branch}"
        cursor = b""
        entries = []
        with self.get_session() as session:
            while True:
                url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
                r = session.get(url, timeout=10)
                r.raise_for_status()
                batch = json.loads(r.content)
                if len(batch) == 0:
                    break

                entries.extend(batch)
                # The API pages by the last file name seen: the cursor is
                # base64(base64('{"file_name":"<path>"}' + ':50')), URL-escaped.
                cursor = base64.b64encode(f'{{"file_name":"{batch[-1]["path"]}"}}'.encode()) + b':50'
                cursor = base64.b64encode(cursor)
                cursor = cursor.replace(b'=', b'%3D')

        return entries

    def _select_files(self, entries, url_prefix, text_only=False, specific_file=None):
        """Choose which entries of a tree listing to download.

        Args:
            entries: listing dicts; each needs ``'path'`` and may carry
                ``{'lfs': {'oid': ...}}``.
            url_prefix: prepended to each path to form its download URL.
            text_only: keep only text/tokenizer files.
            specific_file: when non-empty, only this path is considered.

        Returns:
            ``(links, sha256, is_lora, has_gguf)``. ``sha256`` only covers the
            files that end up in ``links``, so checksum validation does not
            flag files that were skipped on purpose.
        """
        specific_file = specific_file or None
        selected = []  # (path, kind) pairs, in listing order
        sha256 = []
        is_lora = False

        for entry in entries:
            fname = entry['path']
            if specific_file is not None and fname != specific_file:
                continue

            if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
                is_lora = True

            is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname)
            is_safetensors = re.match(r".*\.safetensors", fname)
            is_pt = re.match(r".*\.pt", fname)
            is_gguf = re.match(r'.*\.gguf', fname)
            is_tiktoken = re.match(r".*\.tiktoken", fname)
            is_tokenizer = re.match(r"(tokenizer|ice|spiece).*\.model", fname) or is_tiktoken
            is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer
            if not any((is_pytorch, is_safetensors, is_pt, is_gguf, is_tokenizer, is_text)):
                continue

            if 'lfs' in entry:
                sha256.append([fname, entry['lfs']['oid']])

            if is_text:
                kind = 'text'
            elif text_only:
                continue
            elif is_safetensors:
                kind = 'safetensors'
            elif is_pytorch:
                kind = 'pytorch'
            elif is_pt:
                kind = 'pt'
            else:
                # any() above guarantees a non-text file is one of the four weight kinds.
                kind = 'gguf'
            selected.append((fname, kind))

        kinds = {kind for _, kind in selected}
        has_gguf = 'gguf' in kinds

        # If both pytorch and safetensors are available, download safetensors only.
        # Also if GGUF and safetensors are available, download only safetensors.
        if 'safetensors' in kinds and kinds & {'pytorch', 'pt', 'gguf'}:
            has_gguf = False
            selected = [(f, k) for f, k in selected if k not in ('pytorch', 'pt', 'gguf')]

        # For GGUF, try to download only the Q4_K_M if no specific file is specified.
        # If not present, exclude all GGUFs, as that's likely a repository with both
        # GGUF and fp16 files. The match is done on the full URL, as before.
        if has_gguf and specific_file is None:
            q4km = [(f, k) for f, k in selected if 'q4_k_m' in (url_prefix + f).lower()]
            if q4km:
                selected = q4km
            else:
                selected = [(f, k) for f, k in selected if not (url_prefix + f).lower().endswith('.gguf')]

        kept = {f for f, _ in selected}
        sha256 = [pair for pair in sha256 if pair[0] in kept]
        links = [url_prefix + f for f, _ in selected]
        return links, sha256, is_lora, has_gguf

    def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None):
        """Return the folder to download into.

        The base folder is ``model_dir`` when given, else ``loras`` for LoRA
        adapters or ``models``. llama.cpp (GGUF) downloads go straight into
        the base folder. Otherwise a ``user_model[_branch]`` subfolder is used;
        the branch suffix is omitted for ``main``.
        """
        if model_dir:
            base_folder = model_dir
        else:
            base_folder = 'loras' if is_lora else 'models'

        # If the model is of type GGUF, save directly in the base_folder
        if is_llamacpp:
            return Path(base_folder)

        output_folder = '_'.join(model.split('/')[-2:])
        if branch != 'main':
            output_folder += f'_{branch}'

        return Path(base_folder) / output_folder

    def get_single_file(self, url, output_folder, start_from_scratch=False):
        """Download ``url`` into ``output_folder``, resuming a partial file when possible.

        The file is named after the last path component of ``url``. If it
        exists and ``start_from_scratch`` is false, a size probe decides
        whether it is already complete (nothing is done) or whether the rest
        should be requested with an HTTP ``Range`` header and appended.

        Request errors are retried up to 7 attempts with exponential back-off
        (2, 4, 8, ... seconds). After the last failure a message is printed
        and the method returns without raising.
        """
        filename = Path(url.rsplit('/', 1)[1])
        output_path = output_folder / filename
        max_retries = 7
        attempt = 0
        while attempt < max_retries:
            attempt += 1
            headers = {}
            mode = 'wb'
            try:
                # Fresh session per attempt, closed afterwards so pooled
                # connections are not leaked across retries.
                with self.get_session() as session:
                    if output_path.exists() and not start_from_scratch:
                        # Resume download: probe the full size first. The
                        # streamed probe is closed without reading its body.
                        with session.get(url, stream=True, timeout=20) as r:
                            total_size = int(r.headers.get('content-length', 0))
                        if output_path.stat().st_size >= total_size:
                            return

                        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
                        mode = 'ab'

                    with session.get(url, stream=True, headers=headers, timeout=30) as r:
                        r.raise_for_status()  # If status is not 2xx, raise an error
                        if mode == 'ab' and r.status_code != 206:
                            # The server ignored the Range header and is sending the
                            # whole file; appending it would corrupt the partial file.
                            mode = 'wb'

                        total_size = int(r.headers.get('content-length', 0))
                        self._write_stream(r, output_path, mode, total_size, filename)
                break  # Exit loop if successful
            except (RequestException, ConnectionError, Timeout) as e:
                print(f"Error downloading {filename}: {e}.")
                print(f"That was attempt {attempt}/{max_retries}.", end=' ')
                if attempt < max_retries:
                    print(f"Retry begins in {2 ** attempt} seconds.")
                    sleep(2 ** attempt)
                else:
                    print("Failed to download after the maximum number of attempts.")

    def _write_stream(self, response, output_path, mode, total_size, filename):
        """Stream ``response`` into ``output_path`` (opened with ``mode``), reporting progress."""
        block_size = 1024 * 1024  # 1MB
        tqdm_kwargs = {
            'total': total_size,
            'unit': 'iB',
            'unit_scale': True,
            'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt}'
        }
        if 'COLAB_GPU' in os.environ:
            tqdm_kwargs.update({
                'position': 0,
                'leave': True
            })

        with open(output_path, mode) as f, tqdm.tqdm(**tqdm_kwargs) as t:
            count = 0
            for data in response.iter_content(block_size):
                f.write(data)
                t.update(len(data))
                if total_size != 0 and self.progress_bar is not None:
                    count += len(data)
                    self.progress_bar(float(count) / float(total_size), f"{filename}")

    def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
        """Download every URL in ``file_list`` with ``threads`` worker threads."""
        thread_map(
            lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch),
            file_list, max_workers=threads, disable=True)

    def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=4, specific_file=None, is_llamacpp=False):
        """Create ``output_folder``, write download metadata and fetch ``links``.

        ``huggingface-metadata.txt`` (url, branch, date, sha256 list) is written
        unless ``is_llamacpp`` is set. ``progress_bar`` is stored for use by
        get_single_file().
        """
        self.progress_bar = progress_bar

        # Create the folder and writing the metadata
        output_folder.mkdir(parents=True, exist_ok=True)

        if not is_llamacpp:
            metadata = f'url: https://huggingface.co/{model}\n' \
                f'branch: {branch}\n' \
                f'download date: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n'

            sha256_str = '\n'.join([f'    {item[1]} {item[0]}' for item in sha256])
            if sha256_str:
                metadata += f'sha256sum:\n{sha256_str}'

            metadata += '\n'
            (output_folder / 'huggingface-metadata.txt').write_text(metadata)

        if specific_file:
            print(f"Downloading {specific_file} to {output_folder}")
        else:
            print(f"Downloading the model to {output_folder}")

        self.start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)

    def check_model_files(self, model, branch, links, sha256, output_folder):
        """Validate downloaded files against their expected SHA-256 digests.

        ``model``, ``branch`` and ``links`` are unused and kept for interface
        compatibility. Files are hashed in 1 MiB chunks so multi-gigabyte
        weights are not loaded into memory at once.

        Returns:
            True if every listed file exists and matches its digest, else False.
        """
        validated = True
        for fname, expected in sha256:
            fpath = output_folder / fname
            if not fpath.exists():
                print(f"The following file is missing: {fpath}")
                validated = False
                continue

            digest = hashlib.sha256()
            with open(fpath, "rb") as f:
                for chunk in iter(lambda: f.read(1024 * 1024), b""):
                    digest.update(chunk)

            if digest.hexdigest() != expected:
                print(f'Checksum failed: {fname}  {expected}')
                validated = False
            else:
                print(f'Checksum validated: {fname}  {expected}')

        if validated:
            print('[+] Validated checksums of all model files!')
        else:
            print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')

        return validated
2023-04-01 03:52:52 +02:00
2023-04-09 21:59:59 +02:00
if __name__ == '__main__':
    # Command-line entry point: parse arguments, resolve the file list from
    # the Hub, then either validate existing files (--check) or download them.
    parser = argparse.ArgumentParser()
    parser.add_argument('MODEL', type=str, default=None, nargs='?')
    parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
    parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.')
    parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
    parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
    parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
    parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).')
    parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
    parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
    parser.add_argument('--max-retries', type=int, default=5, help='Max retries count when get error in download time.')
    args = parser.parse_args()

    branch = args.branch
    model = args.MODEL
    specific_file = args.specific_file

    if model is None:
        print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
        # Non-zero status so shells and scripts can detect the failure.
        sys.exit(1)

    downloader = ModelDownloader(max_retries=args.max_retries)

    # Clean up the model/branch names
    try:
        model, branch = downloader.sanitize_model_and_branch_names(model, branch)
    except ValueError as err_branch:
        print(f"Error: {err_branch}")
        sys.exit(1)

    # Get the download links from Hugging Face
    links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file)

    # Get the output folder
    if args.output:
        output_folder = Path(args.output)
    else:
        output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir)

    if args.check:
        # Check previously downloaded files
        downloader.check_model_files(model, branch, links, sha256, output_folder)
    else:
        # Download files. --clean must be forwarded, otherwise it is parsed
        # but has no effect (downloads always resume).
        downloader.download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=args.clean, specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp)