Refactor download process to use multiprocessing

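The previous implementation built a `multiprocessing.Pool` inline in the `__main__` block and waited on `pool.map`; note that the `--threads` option sets the number of worker processes, not threads.
This commit moves that step into a dedicated `download_files` helper, which feeds the file list to the pool with `imap_unordered` and reports progress through `tqdm`, and moves argument parsing under the `if __name__ == '__main__':` guard.
Because each worker is a separate process, downloads are not serialized by the Global Interpreter Lock (GIL) and can run in true parallel across multiple CPUs.
This results in significantly faster downloads, particularly for large models.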
This commit is contained in:
Nikita Skakun 2023-03-28 14:24:23 -07:00
parent aebd3cf110
commit 4d8e101006

View File

@ -17,13 +17,6 @@ from pathlib import Path
import requests import requests
import tqdm import tqdm
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
args = parser.parse_args()
def get_file(args): def get_file(args):
url = args[0] url = args[0]
output_folder = args[1] output_folder = args[1]
@ -150,7 +143,22 @@ def get_download_links_from_huggingface(model, branch):
return links, is_lora return links, is_lora
def download_files(file_list, output_folder, num_processes=8):
    """Download every entry of ``file_list`` with a pool of worker processes.

    Each entry is handed to ``get_file`` as a 4-tuple
    ``(url, output_folder, position, total)``, where ``position`` is the
    1-based index of the entry and ``total`` is ``len(file_list)``.
    Results are consumed as they complete, in no particular order, and a
    ``tqdm`` bar advances once per finished entry.

    Args:
        file_list: Sized sequence of URLs to fetch.
        output_folder: Destination passed through unchanged to ``get_file``.
        num_processes: Number of worker processes in the pool.

    Returns:
        None. Whatever ``get_file`` returns is discarded.
    """
    with multiprocessing.Pool(processes=num_processes) as workers:
        total = len(file_list)
        jobs = [
            (link, output_folder, position, total)
            for position, link in enumerate(file_list, start=1)
        ]
        finished = workers.imap_unordered(get_file, jobs)
        progress = tqdm.tqdm(finished, total=len(jobs))
        # Drain the iterator purely to drive the progress bar.
        for _ in progress:
            pass
        # Stop accepting work and wait for the workers before leaving the block.
        workers.close()
        workers.join()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
args = parser.parse_args()
model = args.MODEL model = args.MODEL
branch = args.branch branch = args.branch
if model is None: if model is None:
@ -179,7 +187,4 @@ if __name__ == '__main__':
# Downloading the files # Downloading the files
print(f"Downloading the model to {output_folder}") print(f"Downloading the model to {output_folder}")
pool = multiprocessing.Pool(processes=args.threads) download_files(links, output_folder, num_processes=args.threads)
results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))])
pool.close()
pool.join()