mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 09:19:23 +01:00
Merge pull request #618 from nikita-skakun/optimize-download-model
Improve download-model.py progress bar with multiple threads
This commit is contained in:
commit
9104164297
@ -10,13 +10,13 @@ import argparse
|
|||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from tqdm.contrib.concurrent import thread_map
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||||
@ -26,22 +26,15 @@ parser.add_argument('--text-only', action='store_true', help='Only download text
|
|||||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def get_file(args):
|
def get_file(url, output_folder):
|
||||||
url = args[0]
|
|
||||||
output_folder = args[1]
|
|
||||||
idx = args[2]
|
|
||||||
tot = args[3]
|
|
||||||
|
|
||||||
print(f"Downloading file {idx} of {tot}...")
|
|
||||||
r = requests.get(url, stream=True)
|
r = requests.get(url, stream=True)
|
||||||
with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
|
with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
block_size = 1024
|
block_size = 1024
|
||||||
t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
|
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
||||||
for data in r.iter_content(block_size):
|
for data in r.iter_content(block_size):
|
||||||
t.update(len(data))
|
t.update(len(data))
|
||||||
f.write(data)
|
f.write(data)
|
||||||
t.close()
|
|
||||||
|
|
||||||
def sanitize_branch_name(branch_name):
|
def sanitize_branch_name(branch_name):
|
||||||
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
||||||
@ -152,6 +145,9 @@ def get_download_links_from_huggingface(model, branch):
|
|||||||
|
|
||||||
return links, is_lora
|
return links, is_lora
|
||||||
|
|
||||||
|
def download_files(file_list, output_folder, num_threads=8):
|
||||||
|
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, verbose=False)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model = args.MODEL
|
model = args.MODEL
|
||||||
branch = args.branch
|
branch = args.branch
|
||||||
@ -192,7 +188,4 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# Downloading the files
|
# Downloading the files
|
||||||
print(f"Downloading the model to {output_folder}")
|
print(f"Downloading the model to {output_folder}")
|
||||||
pool = multiprocessing.Pool(processes=args.threads)
|
download_files(links, output_folder, args.threads)
|
||||||
results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))])
|
|
||||||
pool.close()
|
|
||||||
pool.join()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user