mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-01 20:04:04 +01:00
Add a download progress bar to the web UI. (#2472)
* Show download progress on the model screen. * In case of error, mark as done to clear progress bar. * Increase the iteration block size to reduce overhead.
This commit is contained in:
parent
0d0d849478
commit
447569e31a
@ -194,18 +194,25 @@ class ModelDownloader:
|
|||||||
r = self.s.get(url, stream=True, headers=headers, timeout=20)
|
r = self.s.get(url, stream=True, headers=headers, timeout=20)
|
||||||
with open(output_path, mode) as f:
|
with open(output_path, mode) as f:
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
block_size = 1024
|
# Every 4MB we report an update
|
||||||
|
block_size = 4*1024*1024
|
||||||
|
|
||||||
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
||||||
|
count = 0
|
||||||
for data in r.iter_content(block_size):
|
for data in r.iter_content(block_size):
|
||||||
t.update(len(data))
|
t.update(len(data))
|
||||||
f.write(data)
|
f.write(data)
|
||||||
|
if self.progress_bar is not None:
|
||||||
|
count += len(data)
|
||||||
|
self.progress_bar(float(count)/float(total_size), f"Downloading {filename}")
|
||||||
|
|
||||||
|
|
||||||
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1):
|
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1):
|
||||||
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
||||||
|
|
||||||
|
|
||||||
def download_model_files(self, model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
|
def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar = None, start_from_scratch=False, threads=1):
|
||||||
|
self.progress_bar = progress_bar
|
||||||
# Creating the folder and writing the metadata
|
# Creating the folder and writing the metadata
|
||||||
if not output_folder.exists():
|
if not output_folder.exists():
|
||||||
output_folder.mkdir(parents=True, exist_ok=True)
|
output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
10
server.py
10
server.py
@ -122,7 +122,7 @@ def count_tokens(text):
|
|||||||
return 'Couldn\'t count the number of tokens. Is a tokenizer loaded?'
|
return 'Couldn\'t count the number of tokens. Is a tokenizer loaded?'
|
||||||
|
|
||||||
|
|
||||||
def download_model_wrapper(repo_id):
|
def download_model_wrapper(repo_id, progress=gr.Progress()):
|
||||||
try:
|
try:
|
||||||
downloader_module = importlib.import_module("download-model")
|
downloader_module = importlib.import_module("download-model")
|
||||||
downloader = downloader_module.ModelDownloader()
|
downloader = downloader_module.ModelDownloader()
|
||||||
@ -131,6 +131,7 @@ def download_model_wrapper(repo_id):
|
|||||||
branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"
|
branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"
|
||||||
check = False
|
check = False
|
||||||
|
|
||||||
|
progress(0.0)
|
||||||
yield ("Cleaning up the model/branch names")
|
yield ("Cleaning up the model/branch names")
|
||||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||||
|
|
||||||
@ -141,13 +142,16 @@ def download_model_wrapper(repo_id):
|
|||||||
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
||||||
|
|
||||||
if check:
|
if check:
|
||||||
|
progress(0.5)
|
||||||
yield ("Checking previously downloaded files")
|
yield ("Checking previously downloaded files")
|
||||||
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||||
|
progress(1.0)
|
||||||
else:
|
else:
|
||||||
yield (f"Downloading files to {output_folder}")
|
yield (f"Downloading files to {output_folder}")
|
||||||
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=1)
|
||||||
yield ("Done!")
|
yield ("Done!")
|
||||||
except:
|
except:
|
||||||
|
progress(1.0)
|
||||||
yield traceback.format_exc()
|
yield traceback.format_exc()
|
||||||
|
|
||||||
|
|
||||||
@ -276,7 +280,7 @@ def create_model_menus():
|
|||||||
save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False)
|
save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False)
|
||||||
|
|
||||||
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
|
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
|
||||||
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
|
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=True)
|
||||||
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
|
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user