Add a download progress bar to the web UI. (#2472)

* Show download progress on the model screen.

* In case of error, mark as done to clear progress bar.

* Increase the iteration block size to reduce overhead.
This commit is contained in:
Morgan Schweers 2023-06-20 18:59:14 -07:00 committed by GitHub
parent 0d0d849478
commit 447569e31a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 5 deletions

View File

@ -194,18 +194,25 @@ class ModelDownloader:
r = self.s.get(url, stream=True, headers=headers, timeout=20) r = self.s.get(url, stream=True, headers=headers, timeout=20)
with open(output_path, mode) as f: with open(output_path, mode) as f:
total_size = int(r.headers.get('content-length', 0)) total_size = int(r.headers.get('content-length', 0))
block_size = 1024 # Every 4MB we report an update
block_size = 4*1024*1024
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t: with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
count = 0
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
t.update(len(data)) t.update(len(data))
f.write(data) f.write(data)
if self.progress_bar is not None:
count += len(data)
self.progress_bar(float(count)/float(total_size), f"Downloading {filename}")
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1): def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1):
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True) thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
def download_model_files(self, model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1): def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar = None, start_from_scratch=False, threads=1):
self.progress_bar = progress_bar
# Creating the folder and writing the metadata # Creating the folder and writing the metadata
if not output_folder.exists(): if not output_folder.exists():
output_folder.mkdir(parents=True, exist_ok=True) output_folder.mkdir(parents=True, exist_ok=True)

View File

@ -122,7 +122,7 @@ def count_tokens(text):
return 'Couldn\'t count the number of tokens. Is a tokenizer loaded?' return 'Couldn\'t count the number of tokens. Is a tokenizer loaded?'
def download_model_wrapper(repo_id): def download_model_wrapper(repo_id, progress=gr.Progress()):
try: try:
downloader_module = importlib.import_module("download-model") downloader_module = importlib.import_module("download-model")
downloader = downloader_module.ModelDownloader() downloader = downloader_module.ModelDownloader()
@ -131,6 +131,7 @@ def download_model_wrapper(repo_id):
branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main" branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"
check = False check = False
progress(0.0)
yield ("Cleaning up the model/branch names") yield ("Cleaning up the model/branch names")
model, branch = downloader.sanitize_model_and_branch_names(model, branch) model, branch = downloader.sanitize_model_and_branch_names(model, branch)
@ -141,13 +142,16 @@ def download_model_wrapper(repo_id):
output_folder = downloader.get_output_folder(model, branch, is_lora) output_folder = downloader.get_output_folder(model, branch, is_lora)
if check: if check:
progress(0.5)
yield ("Checking previously downloaded files") yield ("Checking previously downloaded files")
downloader.check_model_files(model, branch, links, sha256, output_folder) downloader.check_model_files(model, branch, links, sha256, output_folder)
progress(1.0)
else: else:
yield (f"Downloading files to {output_folder}") yield (f"Downloading files to {output_folder}")
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1) downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=1)
yield ("Done!") yield ("Done!")
except: except:
progress(1.0)
yield traceback.format_exc() yield traceback.format_exc()
@ -276,7 +280,7 @@ def create_model_menus():
save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False) save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False)
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=True)
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load) shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)