Fix the download script for long lists of files on Hugging Face (HF)

This commit is contained in:
oobabooga 2023-03-10 00:41:10 -03:00
parent 9849aac0f1
commit 249c268176

View File

@@ -5,7 +5,9 @@ Example:
python download-model.py facebook/opt-1.3b python download-model.py facebook/opt-1.3b
''' '''
import argparse import argparse
import base64
import json import json
import multiprocessing import multiprocessing
import re import re
@@ -93,14 +95,18 @@ facebook/opt-1.3b
def get_download_links_from_huggingface(model, branch): def get_download_links_from_huggingface(model, branch):
base = "https://huggingface.co" base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor=" page = f"/api/models/{model}/tree/{branch}?cursor="
cursor = b""
links = [] links = []
classifications = [] classifications = []
has_pytorch = False has_pytorch = False
has_safetensors = False has_safetensors = False
while page is not None: while True:
content = requests.get(f"{base}{page}").content content = requests.get(f"{base}{page}{cursor.decode()}").content
dict = json.loads(content) dict = json.loads(content)
if len(dict) == 0:
break
for i in range(len(dict)): for i in range(len(dict)):
fname = dict[i]['path'] fname = dict[i]['path']
@@ -123,8 +129,9 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True has_pytorch = True
classifications.append('pytorch') classifications.append('pytorch')
#page = dict['nextUrl'] cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
page = None cursor = base64.b64encode(cursor)
cursor = cursor.replace(b'=', b'%3D')
# If both pytorch and safetensors are available, download safetensors only # If both pytorch and safetensors are available, download safetensors only
if has_pytorch and has_safetensors: if has_pytorch and has_safetensors: