Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2025-01-27 20:43:19 +01:00)
Downloader: Make progress bars not jump around
Adapted from: https://gist.github.com/NiklasBeierl/13096bfdd8b2084da8c1163dd06f91d3
commit 3d4f3e423c
parent 71a551a622
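
The approach, for reference: each download thread claims a free slot in a shared multiprocessing.Array and passes that slot index as tqdm's position argument (with leave=False), so every file's bar stays pinned to its own terminal row instead of reshuffling as threads start and finish; the slot is released when the download ends. Below is a minimal, self-contained sketch of that pattern, not the repository's code; the names fake_download, acquire_slot, release_slot and NUM_THREADS are illustrative only.

# Minimal sketch only; fake_download and the slot helpers are illustrative,
# not taken from the repository's downloader.
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Array
from time import sleep

import tqdm

NUM_THREADS = 4
# One byte per worker slot: 0 = free, 1 = taken.
slots = Array("B", [0] * NUM_THREADS)


def acquire_slot():
    # Find the first free slot under the array's lock and mark it taken.
    with slots.get_lock():
        for i in range(len(slots)):
            if slots[i] == 0:
                slots[i] = 1
                return i
        return 0  # fallback, mirrors the commit


def release_slot(slot):
    with slots.get_lock():
        slots[slot] = 0


def fake_download(name, total=50):
    position = acquire_slot()
    try:
        # 'position' pins this bar to a fixed row; 'leave=False' clears the
        # row on completion so the next file can reuse it.
        with tqdm.tqdm(total=total, desc=name, position=position, leave=False) as bar:
            for _ in range(total):
                sleep(0.02)  # stand-in for writing a downloaded chunk
                bar.update(1)
    finally:
        release_slot(position)


if __name__ == "__main__":
    # Share tqdm's write lock across threads so bar refreshes don't interleave.
    tqdm.tqdm.set_lock(tqdm.tqdm.get_lock())
    files = [f"file-{i}.bin" for i in range(8)]
    with ThreadPoolExecutor(max_workers=NUM_THREADS) as pool:
        list(pool.map(fake_download, files))

Sharing tqdm's write lock across the pool (tqdm.tqdm.set_lock(tqdm.tqdm.get_lock())) keeps concurrent refreshes from interleaving on stdout, which is why the commit calls it in start_download_threads before thread_map.
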
@@ -14,6 +14,7 @@ import json
 import os
 import re
 import sys
+from multiprocessing import Array
 from pathlib import Path
 from time import sleep
 
@@ -27,9 +28,10 @@ base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
 
 
 class ModelDownloader:
-    def __init__(self, max_retries=5):
+    def __init__(self, max_retries=7):
         self.max_retries = max_retries
         self.session = self.get_session()
+        self._progress_bar_slots = None
 
     def get_session(self):
         session = requests.Session()
@@ -186,73 +188,112 @@ class ModelDownloader:
         output_folder = Path(base_folder) / output_folder
         return output_folder
 
+    @property
+    def progress_bar_slots(self):
+        if self._progress_bar_slots is None:
+            raise RuntimeError("Progress bar slots not initialized. Start download threads first.")
+
+        return self._progress_bar_slots
+
+    def initialize_progress_bar_slots(self, num_threads):
+        self._progress_bar_slots = Array("B", [0] * num_threads)
+
+    def get_progress_bar_position(self):
+        with self.progress_bar_slots.get_lock():
+            for i in range(len(self.progress_bar_slots)):
+                if self.progress_bar_slots[i] == 0:
+                    self.progress_bar_slots[i] = 1
+                    return i
+
+            return 0  # fallback
+
+    def release_progress_bar_position(self, slot):
+        with self.progress_bar_slots.get_lock():
+            self.progress_bar_slots[slot] = 0
+
     def get_single_file(self, url, output_folder, start_from_scratch=False):
         filename = Path(url.rsplit('/', 1)[1])
         output_path = output_folder / filename
+        progress_bar_position = self.get_progress_bar_position()
 
-        max_retries = 7
+        max_retries = self.max_retries
         attempt = 0
-        while attempt < max_retries:
-            attempt += 1
-            session = self.session
-            headers = {}
-            mode = 'wb'
+        try:
+            while attempt < max_retries:
+                attempt += 1
+                session = self.session
+                headers = {}
+                mode = 'wb'
 
-            try:
-                if output_path.exists() and not start_from_scratch:
-                    # Resume download
-                    r = session.get(url, stream=True, timeout=20)
-                    total_size = int(r.headers.get('content-length', 0))
-                    if output_path.stat().st_size >= total_size:
-                        return
+                try:
+                    if output_path.exists() and not start_from_scratch:
+                        # Resume download
+                        r = session.get(url, stream=True, timeout=20)
+                        total_size = int(r.headers.get('content-length', 0))
+                        if output_path.stat().st_size >= total_size:
+                            return
 
-                    headers = {'Range': f'bytes={output_path.stat().st_size}-'}
-                    mode = 'ab'
+                        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+                        mode = 'ab'
 
-                with session.get(url, stream=True, headers=headers, timeout=30) as r:
-                    r.raise_for_status()  # If status is not 2xx, raise an error
-                    total_size = int(r.headers.get('content-length', 0))
-                    block_size = 1024 * 1024  # 1MB
+                    with session.get(url, stream=True, headers=headers, timeout=30) as r:
+                        r.raise_for_status()  # If status is not 2xx, raise an error
+                        total_size = int(r.headers.get('content-length', 0))
+                        block_size = 1024 * 1024  # 1MB
 
-                    filename_str = str(filename)  # Convert PosixPath to string if necessary
+                        filename_str = str(filename)  # Convert PosixPath to string if necessary
 
-                    tqdm_kwargs = {
-                        'total': total_size,
-                        'unit': 'B',
-                        'unit_scale': True,
-                        'unit_divisor': 1024,
-                        'bar_format': '{desc}{percentage:3.0f}%|{bar:50}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
-                        'desc': f"{filename_str}: "
-                    }
+                        tqdm_kwargs = {
+                            'total': total_size,
+                            'unit': 'B',
+                            'unit_scale': True,
+                            'unit_divisor': 1024,
+                            'bar_format': '{desc}{percentage:3.0f}%|{bar:50}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
+                            'desc': f"{filename_str}: ",
+                            'position': progress_bar_position,
+                            'leave': False
+                        }
 
-                    if 'COLAB_GPU' in os.environ:
-                        tqdm_kwargs.update({
-                            'position': 0,
-                            'leave': True
-                        })
+                        if 'COLAB_GPU' in os.environ:
+                            tqdm_kwargs.update({
+                                'position': 0,
+                                'leave': True
+                            })
 
-                    with open(output_path, mode) as f:
-                        with tqdm.tqdm(**tqdm_kwargs) as t:
-                            count = 0
-                            for data in r.iter_content(block_size):
-                                f.write(data)
-                                t.update(len(data))
-                                if total_size != 0 and self.progress_bar is not None:
-                                    count += len(data)
-                                    self.progress_bar(float(count) / float(total_size), f"{filename_str}")
+                        with open(output_path, mode) as f:
+                            with tqdm.tqdm(**tqdm_kwargs) as t:
+                                count = 0
+                                for data in r.iter_content(block_size):
+                                    f.write(data)
+                                    t.update(len(data))
+                                    if total_size != 0 and self.progress_bar is not None:
+                                        count += len(data)
+                                        self.progress_bar(float(count) / float(total_size), f"{filename_str}")
 
-                break  # Exit loop if successful
-            except (RequestException, ConnectionError, Timeout) as e:
-                print(f"Error downloading {filename}: {e}.")
-                print(f"That was attempt {attempt}/{max_retries}.", end=' ')
-                if attempt < max_retries:
-                    print(f"Retry begins in {2 ** attempt} seconds.")
-                    sleep(2 ** attempt)
-                else:
-                    print("Failed to download after the maximum number of attempts.")
+                    break  # Exit loop if successful
+                except (RequestException, ConnectionError, Timeout) as e:
+                    print(f"Error downloading {filename}: {e}.")
+                    print(f"That was attempt {attempt}/{max_retries}.", end=' ')
+                    if attempt < max_retries:
+                        print(f"Retry begins in {2 ** attempt} seconds.")
+                        sleep(2 ** attempt)
+                    else:
+                        print("Failed to download after the maximum number of attempts.")
+        finally:
+            self.release_progress_bar_position(progress_bar_position)
 
     def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
-        thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
+        self.initialize_progress_bar_slots(threads)
+        tqdm.tqdm.set_lock(tqdm.tqdm.get_lock())
+        try:
+            thread_map(
+                lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch),
+                file_list,
+                max_workers=threads,
+                disable=True
+            )
+        finally:
+            print(f"\nDownload of {len(file_list)} files to {output_folder} completed.")
 
     def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=4, specific_file=None, is_llamacpp=False):
         self.progress_bar = progress_bar
@@ -318,7 +359,7 @@ if __name__ == '__main__':
     parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).')
     parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
    parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
-    parser.add_argument('--max-retries', type=int, default=5, help='Max retries count when get error in download time.')
+    parser.add_argument('--max-retries', type=int, default=7, help='Max retries count when get error in download time.')
     args = parser.parse_args()
 
     branch = args.branch
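
For a quick check after the change, the script can be invoked as before; only the retry default moves from 5 to 7. A hedged example using only the flags visible in this diff, assuming the script's usual organization/model positional argument (the model name is a placeholder):

python download-model.py <organization>/<model> --max-retries 7

Adding --clean would discard any partial files and start the download from scratch instead of resuming.
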