diff --git a/download-model.py b/download-model.py
index 52c7a79e..0f40ab50 100644
--- a/download-model.py
+++ b/download-model.py
@@ -9,6 +9,7 @@ python download-model.py facebook/opt-1.3b
 import argparse
 import base64
 import datetime
+import hashlib
 import json
 import re
 import sys
@@ -24,11 +25,28 @@ parser.add_argument('--branch', type=str, default='main', help='Name of the Git
 parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
 parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
 parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
+parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
+parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
 args = parser.parse_args()
 
 def get_file(url, output_folder):
-    r = requests.get(url, stream=True)
-    with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
+    filename = Path(url.rsplit('/', 1)[1])
+    output_path = output_folder / filename
+    if output_path.exists() and not args.clean:
+        # Check if the file has already been downloaded completely
+        r = requests.get(url, stream=True)
+        total_size = int(r.headers.get('content-length', 0))
+        if output_path.stat().st_size >= total_size:
+            return
+        # Otherwise, resume the download from where it left off
+        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+        mode = 'ab'
+    else:
+        headers = {}
+        mode = 'wb'
+
+    r = requests.get(url, stream=True, headers=headers)
+    with open(output_path, mode) as f:
         total_size = int(r.headers.get('content-length', 0))
         block_size = 1024
         with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
@@ -154,7 +172,7 @@ def get_download_links_from_huggingface(model, branch):
     return links, sha256, is_lora
 
 def download_files(file_list, output_folder, num_threads=8):
-    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads)
+    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
 
 if __name__ == '__main__':
     model = args.MODEL
@@ -184,22 +202,48 @@ if __name__ == '__main__':
     output_folder = f"{'_'.join(model.split('/')[-2:])}"
     if branch != 'main':
         output_folder += f'_{branch}'
-
-    # Creating the folder and writing the metadata
     output_folder = Path(base_folder) / output_folder
-    if not output_folder.exists():
-        output_folder.mkdir()
-    with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
-        f.write(f'url: https://huggingface.co/{model}\n')
-        f.write(f'branch: {branch}\n')
-        f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
-        sha256_str = ''
-        for i in range(len(sha256)):
-            sha256_str += f'    {sha256[i][1]} {sha256[i][0]}\n'
-        if sha256_str != '':
-            f.write(f'sha256sum:\n{sha256_str}')
-
-    # Downloading the files
-    print(f"Downloading the model to {output_folder}")
-    download_files(links, output_folder, args.threads)
-    print()
+    if args.check:
+        # Validate the checksums
+        validated = True
+        for i in range(len(sha256)):
+            fpath = (output_folder / sha256[i][0])
+
+            if not fpath.exists():
+                print(f"The following file is missing: {fpath}")
+                validated = False
+                continue
+
+            with open(output_folder / sha256[i][0], "rb") as f:
+                bytes = f.read()
+                file_hash = hashlib.sha256(bytes).hexdigest()
+                if file_hash != sha256[i][1]:
+                    print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
+                    validated = False
+                else:
+                    print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
+
+        if validated:
+            print('[+] Validated checksums of all model files!')
+        else:
+            print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
+
+    else:
+
+        # Creating the folder and writing the metadata
+        if not output_folder.exists():
+            output_folder.mkdir()
+        with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
+            f.write(f'url: https://huggingface.co/{model}\n')
+            f.write(f'branch: {branch}\n')
+            f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
+            sha256_str = ''
+            for i in range(len(sha256)):
+                sha256_str += f'    {sha256[i][1]} {sha256[i][0]}\n'
+            if sha256_str != '':
+                f.write(f'sha256sum:\n{sha256_str}')
+
+        # Downloading the files
+        print(f"Downloading the model to {output_folder}")
+        download_files(links, output_folder, args.threads)
\ No newline at end of file
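A minimal standalone sketch of the resume-via-Range-header approach that the patched get_file() relies on, shown for reference; the function name resume_download and the example URL are illustrative placeholders, not part of the patch:

# Sketch: resume an interrupted download with an HTTP Range request,
# assuming the server honors partial-content responses.
import requests
from pathlib import Path

def resume_download(url, output_path):
    output_path = Path(output_path)
    if output_path.exists():
        # Ask only for the bytes that are still missing and append them.
        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
        mode = 'ab'
    else:
        headers = {}
        mode = 'wb'

    with requests.get(url, stream=True, headers=headers) as r:
        with open(output_path, mode) as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)

# Example (hypothetical file from the model named in the script's docstring):
# resume_download('https://huggingface.co/facebook/opt-1.3b/resolve/main/config.json', 'config.json')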