Add --branch option to the model download script

This commit is contained in:
81300 2023-01-20 22:51:56 +02:00
parent c0f2367b54
commit fffd49e64e
No known key found for this signature in database

View File

@ -10,8 +10,16 @@ import requests
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
import multiprocessing import multiprocessing
import tqdm import tqdm
import sys
from sys import argv from sys import argv
import argparse
from pathlib import Path from pathlib import Path
import re
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str)
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
args = parser.parse_args()
def get_file(args): def get_file(args):
url = args[0] url = args[0]
@ -27,12 +35,32 @@ def get_file(args):
f.write(data) f.write(data)
t.close() t.close()
def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if pattern.match(branch_name):
return branch_name
else:
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
if __name__ == '__main__': if __name__ == '__main__':
model = argv[1] model = argv[1]
if model[-1] == '/': if model[-1] == '/':
model = model[:-1] model = model[:-1]
url = f'https://huggingface.co/{model}/tree/main' branch = args.branch
output_folder = Path("models") / model.split('/')[-1] if args.branch is None:
branch = 'main'
else:
try:
branch_name = args.branch
branch = sanitize_branch_name(branch_name)
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
url = f'https://huggingface.co/{model}/tree/{branch}'
if branch != 'main':
output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
else:
output_folder = Path("models") / model.split('/')[-1]
if not output_folder.exists(): if not output_folder.exists():
output_folder.mkdir() output_folder.mkdir()
@ -43,7 +71,7 @@ if __name__ == '__main__':
downloads = [] downloads = []
for link in links: for link in links:
href = link.get('href')[1:] href = link.get('href')[1:]
if href.startswith(f'{model}/resolve/main'): if href.startswith(f'{model}/resolve/{branch}'):
if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href): if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href):
downloads.append(f'https://huggingface.co/{href}') downloads.append(f'https://huggingface.co/{href}')