From fffd49e64e215d2e2b87d1c8a8d92b5b5debf2ac Mon Sep 17 00:00:00 2001 From: 81300 <105078168+81300@users.noreply.github.com> Date: Fri, 20 Jan 2023 22:51:56 +0200 Subject: [PATCH 1/3] Add --branch option to the model download script --- download-model.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/download-model.py b/download-model.py index 9733fcfa..7b2e10c9 100644 --- a/download-model.py +++ b/download-model.py @@ -10,8 +10,16 @@ import requests from bs4 import BeautifulSoup import multiprocessing import tqdm +import sys from sys import argv +import argparse from pathlib import Path +import re + +parser = argparse.ArgumentParser() +parser.add_argument('MODEL', type=str) +parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') +args = parser.parse_args() def get_file(args): url = args[0] @@ -27,12 +35,32 @@ def get_file(args): f.write(data) t.close() +def sanitize_branch_name(branch_name): + pattern = re.compile(r"^[a-zA-Z0-9._-]+$") + if pattern.match(branch_name): + return branch_name + else: + raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") + if __name__ == '__main__': model = argv[1] if model[-1] == '/': model = model[:-1] - url = f'https://huggingface.co/{model}/tree/main' - output_folder = Path("models") / model.split('/')[-1] + branch = args.branch + if args.branch is None: + branch = 'main' + else: + try: + branch_name = args.branch + branch = sanitize_branch_name(branch_name) + except ValueError as err_branch: + print(f"Error: {err_branch}") + sys.exit() + url = f'https://huggingface.co/{model}/tree/{branch}' + if branch != 'main': + output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}') + else: + output_folder = Path("models") / model.split('/')[-1] if not output_folder.exists(): output_folder.mkdir() @@ -43,7 +71,7 @@ if __name__ == '__main__': downloads = [] for link in links: href = link.get('href')[1:] - if href.startswith(f'{model}/resolve/main'): + if href.startswith(f'{model}/resolve/{branch}'): if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href): downloads.append(f'https://huggingface.co/{href}') From 18ef72d7c0a3e3786bc7f4c135f69233ddcb6ca3 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 21 Jan 2023 00:38:23 -0300 Subject: [PATCH 2/3] Update download-model.py --- download-model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/download-model.py b/download-model.py index 7b2e10c9..0519b18c 100644 --- a/download-model.py +++ b/download-model.py @@ -11,7 +11,6 @@ from bs4 import BeautifulSoup import multiprocessing import tqdm import sys -from sys import argv import argparse from pathlib import Path import re @@ -43,7 +42,7 @@ def sanitize_branch_name(branch_name): raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") if __name__ == '__main__': - model = argv[1] + model = args.model if model[-1] == '/': model = model[:-1] branch = args.branch From 1e541d4882cdda3c594c324a8c3a4415aecb909b Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 21 Jan 2023 00:43:00 -0300 Subject: [PATCH 3/3] Update download-model.py --- download-model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/download-model.py b/download-model.py index 0519b18c..a83ed17a 100644 --- a/download-model.py +++ b/download-model.py @@ -42,7 +42,7 @@ def sanitize_branch_name(branch_name): raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") if __name__ == '__main__': - model = args.model + model = args.MODEL if model[-1] == '/': model = model[:-1] branch = args.branch