mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 19:09:32 +01:00
Add --threads option to the download script
This commit is contained in:
parent
03f084f311
commit
9215e281ba
@ -18,12 +18,16 @@ import re
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('MODEL', type=str)
|
parser.add_argument('MODEL', type=str)
|
||||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||||
|
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def get_file(args):
|
def get_file(args):
|
||||||
url = args[0]
|
url = args[0]
|
||||||
output_folder = args[1]
|
output_folder = args[1]
|
||||||
|
idx = args[2]
|
||||||
|
tot = args[3]
|
||||||
|
|
||||||
|
print(f"Downloading file {idx} of {tot}...")
|
||||||
r = requests.get(url, stream=True)
|
r = requests.get(url, stream=True)
|
||||||
with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
|
with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
@ -77,8 +81,8 @@ if __name__ == '__main__':
|
|||||||
downloads.append(f'https://huggingface.co/{href}')
|
downloads.append(f'https://huggingface.co/{href}')
|
||||||
|
|
||||||
# Downloading the files
|
# Downloading the files
|
||||||
print(f"Downloading the model to {output_folder}...")
|
print(f"Downloading the model to {output_folder}")
|
||||||
pool = multiprocessing.Pool(processes=4)
|
pool = multiprocessing.Pool(processes=args.threads)
|
||||||
results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))])
|
results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))])
|
||||||
pool.close()
|
pool.close()
|
||||||
pool.join()
|
pool.join()
|
||||||
|
Loading…
Reference in New Issue
Block a user