Improve progress bar visual style

This commit reverts the performance improvements of the previous commit for for improved visual style of multithreaded progress bars. The style of the progress bar has been modified to take up the same amount of size to align them.
This commit is contained in:
Nikita Skakun 2023-03-28 18:29:20 -07:00
parent 4d8e101006
commit ff515ec2fe

View File

@ -16,23 +16,17 @@ from pathlib import Path
import requests import requests
import tqdm import tqdm
from tqdm.contrib.concurrent import thread_map
def get_file(args): def get_file(url, output_folder):
url = args[0]
output_folder = args[1]
idx = args[2]
tot = args[3]
print(f"Downloading file {idx} of {tot}...")
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
with open(output_folder / Path(url.split('/')[-1]), 'wb') as f: with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
total_size = int(r.headers.get('content-length', 0)) total_size = int(r.headers.get('content-length', 0))
block_size = 1024 block_size = 1024
t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True) with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
t.update(len(data)) t.update(len(data))
f.write(data) f.write(data)
t.close()
def sanitize_branch_name(branch_name): def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$") pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
@ -143,13 +137,8 @@ def get_download_links_from_huggingface(model, branch):
return links, is_lora return links, is_lora
def download_files(file_list, output_folder, num_processes=8): def download_files(file_list, output_folder, num_threads=8):
with multiprocessing.Pool(processes=num_processes) as pool: thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, verbose=False)
args = [(url, output_folder, idx+1, len(file_list)) for idx, url in enumerate(file_list)]
for _ in tqdm.tqdm(pool.imap_unordered(get_file, args), total=len(args)):
pass
pool.close()
pool.join()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -187,4 +176,4 @@ if __name__ == '__main__':
# Downloading the files # Downloading the files
print(f"Downloading the model to {output_folder}") print(f"Downloading the model to {output_folder}")
download_files(links, output_folder, num_processes=args.threads) download_files(links, output_folder, args.threads)