Add --threads option to the download script

This commit is contained in:
oobabooga 2023-02-03 18:57:12 -03:00
parent 03f084f311
commit 9215e281ba

View File

@ -18,12 +18,16 @@ import re
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str)
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
args = parser.parse_args()
def get_file(args):
url = args[0]
output_folder = args[1]
idx = args[2]
tot = args[3]
print(f"Downloading file {idx} of {tot}...")
r = requests.get(url, stream=True)
with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
total_size = int(r.headers.get('content-length', 0))
@ -77,8 +81,8 @@ if __name__ == '__main__':
downloads.append(f'https://huggingface.co/{href}')
# Downloading the files
print(f"Downloading the model to {output_folder}...")
pool = multiprocessing.Pool(processes=4)
results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))])
print(f"Downloading the model to {output_folder}")
pool = multiprocessing.Pool(processes=args.threads)
results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))])
pool.close()
pool.join()