From 3d4f3e423c28694d35fdc431e37028c4a201de38 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 25 Jan 2025 07:28:31 -0800 Subject: [PATCH] Downloader: Make progress bars not jump around Adapted from: https://gist.github.com/NiklasBeierl/13096bfdd8b2084da8c1163dd06f91d3 --- download-model.py | 149 +++++++++++++++++++++++++++++----------------- 1 file changed, 95 insertions(+), 54 deletions(-) diff --git a/download-model.py b/download-model.py index 8fe94371..8ff1d69c 100644 --- a/download-model.py +++ b/download-model.py @@ -14,6 +14,7 @@ import json import os import re import sys +from multiprocessing import Array from pathlib import Path from time import sleep @@ -27,9 +28,10 @@ base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co" class ModelDownloader: - def __init__(self, max_retries=5): + def __init__(self, max_retries=7): self.max_retries = max_retries self.session = self.get_session() + self._progress_bar_slots = None def get_session(self): session = requests.Session() @@ -186,73 +188,112 @@ class ModelDownloader: output_folder = Path(base_folder) / output_folder return output_folder + @property + def progress_bar_slots(self): + if self._progress_bar_slots is None: + raise RuntimeError("Progress bar slots not initialized. Start download threads first.") + + return self._progress_bar_slots + + def initialize_progress_bar_slots(self, num_threads): + self._progress_bar_slots = Array("B", [0] * num_threads) + + def get_progress_bar_position(self): + with self.progress_bar_slots.get_lock(): + for i in range(len(self.progress_bar_slots)): + if self.progress_bar_slots[i] == 0: + self.progress_bar_slots[i] = 1 + return i + + return 0 # fallback + + def release_progress_bar_position(self, slot): + with self.progress_bar_slots.get_lock(): + self.progress_bar_slots[slot] = 0 + def get_single_file(self, url, output_folder, start_from_scratch=False): filename = Path(url.rsplit('/', 1)[1]) output_path = output_folder / filename + progress_bar_position = self.get_progress_bar_position() - max_retries = 7 + max_retries = self.max_retries attempt = 0 - while attempt < max_retries: - attempt += 1 - session = self.session - headers = {} - mode = 'wb' + try: + while attempt < max_retries: + attempt += 1 + session = self.session + headers = {} + mode = 'wb' - try: - if output_path.exists() and not start_from_scratch: - # Resume download - r = session.get(url, stream=True, timeout=20) - total_size = int(r.headers.get('content-length', 0)) - if output_path.stat().st_size >= total_size: - return + try: + if output_path.exists() and not start_from_scratch: + # Resume download + r = session.get(url, stream=True, timeout=20) + total_size = int(r.headers.get('content-length', 0)) + if output_path.stat().st_size >= total_size: + return - headers = {'Range': f'bytes={output_path.stat().st_size}-'} - mode = 'ab' + headers = {'Range': f'bytes={output_path.stat().st_size}-'} + mode = 'ab' - with session.get(url, stream=True, headers=headers, timeout=30) as r: - r.raise_for_status() # If status is not 2xx, raise an error - total_size = int(r.headers.get('content-length', 0)) - block_size = 1024 * 1024 # 1MB + with session.get(url, stream=True, headers=headers, timeout=30) as r: + r.raise_for_status() # If status is not 2xx, raise an error + total_size = int(r.headers.get('content-length', 0)) + block_size = 1024 * 1024 # 1MB - filename_str = str(filename) # Convert PosixPath to string if necessary + filename_str = str(filename) # Convert PosixPath to string if necessary - tqdm_kwargs = { - 'total': total_size, - 'unit': 'B', - 'unit_scale': True, - 'unit_divisor': 1024, - 'bar_format': '{desc}{percentage:3.0f}%|{bar:50}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]', - 'desc': f"{filename_str}: " - } + tqdm_kwargs = { + 'total': total_size, + 'unit': 'B', + 'unit_scale': True, + 'unit_divisor': 1024, + 'bar_format': '{desc}{percentage:3.0f}%|{bar:50}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]', + 'desc': f"{filename_str}: ", + 'position': progress_bar_position, + 'leave': False + } - if 'COLAB_GPU' in os.environ: - tqdm_kwargs.update({ - 'position': 0, - 'leave': True - }) + if 'COLAB_GPU' in os.environ: + tqdm_kwargs.update({ + 'position': 0, + 'leave': True + }) - with open(output_path, mode) as f: - with tqdm.tqdm(**tqdm_kwargs) as t: - count = 0 - for data in r.iter_content(block_size): - f.write(data) - t.update(len(data)) - if total_size != 0 and self.progress_bar is not None: - count += len(data) - self.progress_bar(float(count) / float(total_size), f"{filename_str}") + with open(output_path, mode) as f: + with tqdm.tqdm(**tqdm_kwargs) as t: + count = 0 + for data in r.iter_content(block_size): + f.write(data) + t.update(len(data)) + if total_size != 0 and self.progress_bar is not None: + count += len(data) + self.progress_bar(float(count) / float(total_size), f"{filename_str}") - break # Exit loop if successful - except (RequestException, ConnectionError, Timeout) as e: - print(f"Error downloading {filename}: {e}.") - print(f"That was attempt {attempt}/{max_retries}.", end=' ') - if attempt < max_retries: - print(f"Retry begins in {2 ** attempt} seconds.") - sleep(2 ** attempt) - else: - print("Failed to download after the maximum number of attempts.") + break # Exit loop if successful + except (RequestException, ConnectionError, Timeout) as e: + print(f"Error downloading {filename}: {e}.") + print(f"That was attempt {attempt}/{max_retries}.", end=' ') + if attempt < max_retries: + print(f"Retry begins in {2 ** attempt} seconds.") + sleep(2 ** attempt) + else: + print("Failed to download after the maximum number of attempts.") + finally: + self.release_progress_bar_position(progress_bar_position) def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4): - thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True) + self.initialize_progress_bar_slots(threads) + tqdm.tqdm.set_lock(tqdm.tqdm.get_lock()) + try: + thread_map( + lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), + file_list, + max_workers=threads, + disable=True + ) + finally: + print(f"\nDownload of {len(file_list)} files to {output_folder} completed.") def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=4, specific_file=None, is_llamacpp=False): self.progress_bar = progress_bar @@ -318,7 +359,7 @@ if __name__ == '__main__': parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') - parser.add_argument('--max-retries', type=int, default=5, help='Max retries count when get error in download time.') + parser.add_argument('--max-retries', type=int, default=7, help='Max retries count when get error in download time.') args = parser.parse_args() branch = args.branch