mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Add --threads option to the download script
This commit is contained in:
parent
03f084f311
commit
9215e281ba
@ -18,12 +18,16 @@ import re
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('MODEL', type=str)
|
||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||
args = parser.parse_args()
|
||||
|
||||
def get_file(args):
|
||||
url = args[0]
|
||||
output_folder = args[1]
|
||||
idx = args[2]
|
||||
tot = args[3]
|
||||
|
||||
print(f"Downloading file {idx} of {tot}...")
|
||||
r = requests.get(url, stream=True)
|
||||
with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
|
||||
total_size = int(r.headers.get('content-length', 0))
|
||||
@ -77,8 +81,8 @@ if __name__ == '__main__':
|
||||
downloads.append(f'https://huggingface.co/{href}')
|
||||
|
||||
# Downloading the files
|
||||
print(f"Downloading the model to {output_folder}...")
|
||||
pool = multiprocessing.Pool(processes=4)
|
||||
results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))])
|
||||
print(f"Downloading the model to {output_folder}")
|
||||
pool = multiprocessing.Pool(processes=args.threads)
|
||||
results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))])
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
Loading…
Reference in New Issue
Block a user