Add --threads option to the download script

This commit is contained in:
oobabooga 2023-02-03 18:57:12 -03:00
parent 03f084f311
commit 9215e281ba

View File

@ -18,12 +18,16 @@ import re
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str) parser.add_argument('MODEL', type=str)
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
args = parser.parse_args() args = parser.parse_args()
def get_file(args): def get_file(args):
url = args[0] url = args[0]
output_folder = args[1] output_folder = args[1]
idx = args[2]
tot = args[3]
print(f"Downloading file {idx} of {tot}...")
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
with open(output_folder / Path(url.split('/')[-1]), 'wb') as f: with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
total_size = int(r.headers.get('content-length', 0)) total_size = int(r.headers.get('content-length', 0))
@ -77,8 +81,8 @@ if __name__ == '__main__':
downloads.append(f'https://huggingface.co/{href}') downloads.append(f'https://huggingface.co/{href}')
# Downloading the files # Downloading the files
print(f"Downloading the model to {output_folder}...") print(f"Downloading the model to {output_folder}")
pool = multiprocessing.Pool(processes=4) pool = multiprocessing.Pool(processes=args.threads)
results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))]) results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))])
pool.close() pool.close()
pool.join() pool.join()