Add a retry mechanism to the model downloader (#5943)

Authored by oobabooga on 2024-04-27 12:25:28 -03:00; committed by GitHub
parent dfdb6fee22
commit 5770e06c48

File: download-model.py

@@ -15,10 +15,12 @@ import os
 import re
 import sys
 from pathlib import Path
+from time import sleep
 
 import requests
 import tqdm
 from requests.adapters import HTTPAdapter
+from requests.exceptions import ConnectionError, RequestException, Timeout
 from tqdm.contrib.concurrent import thread_map
 
 base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
@@ -177,25 +179,30 @@ class ModelDownloader:
         return output_folder
 
     def get_single_file(self, url, output_folder, start_from_scratch=False):
-        session = self.get_session()
         filename = Path(url.rsplit('/', 1)[1])
         output_path = output_folder / filename
-        headers = {}
-        mode = 'wb'
-        if output_path.exists() and not start_from_scratch:
-
-            # Check if the file has already been downloaded completely
-            r = session.get(url, stream=True, timeout=10)
-            total_size = int(r.headers.get('content-length', 0))
-            if output_path.stat().st_size >= total_size:
-                return
-
-            # Otherwise, resume the download from where it left off
-            headers = {'Range': f'bytes={output_path.stat().st_size}-'}
-            mode = 'ab'
-
-        with session.get(url, stream=True, headers=headers, timeout=10) as r:
-            r.raise_for_status()  # Do not continue the download if the request was unsuccessful
-            total_size = int(r.headers.get('content-length', 0))
-            block_size = 1024 * 1024  # 1MB
+
+        max_retries = 7
+        attempt = 0
+        while attempt < max_retries:
+            attempt += 1
+            session = self.get_session()
+            headers = {}
+            mode = 'wb'
+
+            try:
+                if output_path.exists() and not start_from_scratch:
+                    # Resume download
+                    r = session.get(url, stream=True, timeout=20)
+                    total_size = int(r.headers.get('content-length', 0))
+                    if output_path.stat().st_size >= total_size:
+                        return
+
+                    headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+                    mode = 'ab'
+
+                with session.get(url, stream=True, headers=headers, timeout=30) as r:
+                    r.raise_for_status()  # If status is not 2xx, raise an error
+                    total_size = int(r.headers.get('content-length', 0))
+                    block_size = 1024 * 1024  # 1MB
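Note: the resume logic above reduces to a simple rule: compare the size of the partial file on disk with the remote Content-Length, return early if the file is already complete, and otherwise request only the missing tail with an HTTP Range header while opening the file in append mode. A minimal standalone sketch of that rule, outside the ModelDownloader class (the helper name plan_resume and the bare requests.get call are illustrative, not part of this commit):

    from pathlib import Path

    import requests


    def plan_resume(url, output_path: Path, timeout=20):
        # Illustrative helper: decide how to (re)open the target file.
        # Returns None when the file on disk is already complete.
        headers, mode = {}, 'wb'
        if output_path.exists():
            with requests.get(url, stream=True, timeout=timeout) as r:
                total_size = int(r.headers.get('content-length', 0))
            if output_path.stat().st_size >= total_size:
                return None
            # Ask the server for the remaining bytes only and append to the partial file.
            headers = {'Range': f'bytes={output_path.stat().st_size}-'}
            mode = 'ab'
        return headers, mode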
@@ -203,7 +210,7 @@ class ModelDownloader:
                         'total': total_size,
                         'unit': 'iB',
                         'unit_scale': True,
-                'bar_format': '{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}'
+                        'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt}'
                     }
 
                     if 'COLAB_GPU' in os.environ:
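Note: the only change in this hunk is the progress-bar format string: the fixed-width ':6' format specifiers are dropped from the byte counters and the rate, so tqdm no longer pads those fields to a minimum width of six characters. A tiny self-contained illustration of the same settings, fed with fake 1 MiB chunks (the sizes are made up for the example):

    import tqdm

    total_size = 50 * 1024 * 1024  # pretend we are downloading 50 MiB
    block_size = 1024 * 1024
    with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True,
                   bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt}') as t:
        for _ in range(total_size // block_size):
            t.update(block_size)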
@@ -216,12 +223,22 @@ class ModelDownloader:
                         with tqdm.tqdm(**tqdm_kwargs) as t:
                             count = 0
                             for data in r.iter_content(block_size):
-                        t.update(len(data))
                                 f.write(data)
+                                t.update(len(data))
                                 if total_size != 0 and self.progress_bar is not None:
                                     count += len(data)
                                     self.progress_bar(float(count) / float(total_size), f"{filename}")
 
+                break  # Exit loop if successful
+            except (RequestException, ConnectionError, Timeout) as e:
+                print(f"Error downloading {filename}: {e}.")
+                print(f"That was attempt {attempt}/{max_retries}.", end=' ')
+                if attempt < max_retries:
+                    print(f"Retry begins in {2 ** attempt} seconds.")
+                    sleep(2 ** attempt)
+                else:
+                    print("Failed to download after the maximum number of attempts.")
+
     def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
         thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
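Note: with max_retries = 7, the except branch sleeps for 2 ** attempt seconds after every failed attempt except the last, so a download that keeps failing waits 2 + 4 + 8 + 16 + 32 + 64 = 126 seconds in backoff, on top of the request timeouts, before giving up. The same retry-with-exponential-backoff pattern, pulled out of the method into a generic sketch (the wrapper name with_retries is illustrative and not part of the commit; it would be redundant around get_single_file, which already retries internally):

    from time import sleep

    from requests.exceptions import ConnectionError, RequestException, Timeout


    def with_retries(operation, max_retries=7):
        # Illustrative sketch of the backoff pattern used above:
        # retry a flaky network operation, waiting 2, 4, 8, ... seconds between attempts.
        attempt = 0
        while attempt < max_retries:
            attempt += 1
            try:
                return operation()
            except (RequestException, ConnectionError, Timeout) as e:
                print(f"Error: {e}. That was attempt {attempt}/{max_retries}.", end=' ')
                if attempt < max_retries:
                    print(f"Retry begins in {2 ** attempt} seconds.")
                    sleep(2 ** attempt)
                else:
                    print("Failed after the maximum number of attempts.")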