diff --git a/download-model.py b/download-model.py index 9733fcfa..a83ed17a 100644 --- a/download-model.py +++ b/download-model.py @@ -10,8 +10,15 @@ import requests from bs4 import BeautifulSoup import multiprocessing import tqdm -from sys import argv +import sys +import argparse from pathlib import Path +import re + +parser = argparse.ArgumentParser() +parser.add_argument('MODEL', type=str) +parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') +args = parser.parse_args() def get_file(args): url = args[0] @@ -27,12 +34,32 @@ def get_file(args): f.write(data) t.close() +def sanitize_branch_name(branch_name): + pattern = re.compile(r"^[a-zA-Z0-9._-]+$") + if pattern.match(branch_name): + return branch_name + else: + raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") + if __name__ == '__main__': - model = argv[1] + model = args.MODEL if model[-1] == '/': model = model[:-1] - url = f'https://huggingface.co/{model}/tree/main' - output_folder = Path("models") / model.split('/')[-1] + branch = args.branch + if args.branch is None: + branch = 'main' + else: + try: + branch_name = args.branch + branch = sanitize_branch_name(branch_name) + except ValueError as err_branch: + print(f"Error: {err_branch}") + sys.exit() + url = f'https://huggingface.co/{model}/tree/{branch}' + if branch != 'main': + output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}') + else: + output_folder = Path("models") / model.split('/')[-1] if not output_folder.exists(): output_folder.mkdir() @@ -43,7 +70,7 @@ if __name__ == '__main__': downloads = [] for link in links: href = link.get('href')[1:] - if href.startswith(f'{model}/resolve/main'): + if href.startswith(f'{model}/resolve/{branch}'): if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href): downloads.append(f'https://huggingface.co/{href}')