Merge pull request #16 from 81300/model-download

Allow specifying the Hugging Face Git branch when downloading models
This commit is contained in:
oobabooga 2023-01-21 00:43:35 -03:00 committed by GitHub
commit 3f2c1e7170
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,8 +10,15 @@ import requests
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
import multiprocessing import multiprocessing
import tqdm import tqdm
from sys import argv import sys
import argparse
from pathlib import Path from pathlib import Path
import re
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str)
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
args = parser.parse_args()
def get_file(args): def get_file(args):
url = args[0] url = args[0]
@ -27,12 +34,32 @@ def get_file(args):
f.write(data) f.write(data)
t.close() t.close()
def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if pattern.match(branch_name):
return branch_name
else:
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
if __name__ == '__main__': if __name__ == '__main__':
model = argv[1] model = args.MODEL
if model[-1] == '/': if model[-1] == '/':
model = model[:-1] model = model[:-1]
url = f'https://huggingface.co/{model}/tree/main' branch = args.branch
output_folder = Path("models") / model.split('/')[-1] if args.branch is None:
branch = 'main'
else:
try:
branch_name = args.branch
branch = sanitize_branch_name(branch_name)
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
url = f'https://huggingface.co/{model}/tree/{branch}'
if branch != 'main':
output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
else:
output_folder = Path("models") / model.split('/')[-1]
if not output_folder.exists(): if not output_folder.exists():
output_folder.mkdir() output_folder.mkdir()
@ -43,7 +70,7 @@ if __name__ == '__main__':
downloads = [] downloads = []
for link in links: for link in links:
href = link.get('href')[1:] href = link.get('href')[1:]
if href.startswith(f'{model}/resolve/main'): if href.startswith(f'{model}/resolve/{branch}'):
if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href): if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href):
downloads.append(f'https://huggingface.co/{href}') downloads.append(f'https://huggingface.co/{href}')