diff --git a/download-model.py b/download-model.py index dce7e749..48ae449e 100644 --- a/download-model.py +++ b/download-model.py @@ -17,13 +17,6 @@ from pathlib import Path import requests import tqdm -parser = argparse.ArgumentParser() -parser.add_argument('MODEL', type=str, default=None, nargs='?') -parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') -parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') -parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') -args = parser.parse_args() - def get_file(args): url = args[0] output_folder = args[1] @@ -150,7 +143,22 @@ def get_download_links_from_huggingface(model, branch): return links, is_lora +def download_files(file_list, output_folder, num_processes=8): + with multiprocessing.Pool(processes=num_processes) as pool: + args = [(url, output_folder, idx+1, len(file_list)) for idx, url in enumerate(file_list)] + for _ in tqdm.tqdm(pool.imap_unordered(get_file, args), total=len(args)): + pass + pool.close() + pool.join() + if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('MODEL', type=str, default=None, nargs='?') + parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') + parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') + parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') + args = parser.parse_args() + model = args.MODEL branch = args.branch if model is None: @@ -179,7 +187,4 @@ if __name__ == '__main__': # Downloading the files print(f"Downloading the model to {output_folder}") - pool = multiprocessing.Pool(processes=args.threads) - results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))]) - pool.close() - pool.join() + download_files(links, output_folder, num_processes=args.threads)