mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Download models with 4 threads by default
This commit is contained in:
parent
520cbb2ab1
commit
3a9d90c3a1
@ -177,10 +177,10 @@ class ModelDownloader:
|
|||||||
count += len(data)
|
count += len(data)
|
||||||
self.progress_bar(float(count) / float(total_size), f"{filename}")
|
self.progress_bar(float(count) / float(total_size), f"{filename}")
|
||||||
|
|
||||||
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1):
|
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
|
||||||
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
||||||
|
|
||||||
def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1, specific_file=None, is_llamacpp=False):
|
def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=4, specific_file=None, is_llamacpp=False):
|
||||||
self.progress_bar = progress_bar
|
self.progress_bar = progress_bar
|
||||||
|
|
||||||
# Create the folder and writing the metadata
|
# Create the folder and writing the metadata
|
||||||
@ -236,7 +236,7 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.')
|
||||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||||
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
|
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
|
||||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||||
|
@ -225,16 +225,11 @@ def load_lora_wrapper(selected_loras):
|
|||||||
|
|
||||||
def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False):
|
def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False):
|
||||||
try:
|
try:
|
||||||
downloader_module = importlib.import_module("download-model")
|
|
||||||
downloader = downloader_module.ModelDownloader()
|
|
||||||
|
|
||||||
progress(0.0)
|
progress(0.0)
|
||||||
yield ("Cleaning up the model/branch names")
|
downloader = importlib.import_module("download-model").ModelDownloader()
|
||||||
model, branch = downloader.sanitize_model_and_branch_names(repo_id, None)
|
model, branch = downloader.sanitize_model_and_branch_names(repo_id, None)
|
||||||
|
|
||||||
yield ("Getting the download links from Hugging Face")
|
yield ("Getting the download links from Hugging Face")
|
||||||
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file)
|
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file)
|
||||||
|
|
||||||
if return_links:
|
if return_links:
|
||||||
yield '\n\n'.join([f"`{Path(link).name}`" for link in links])
|
yield '\n\n'.join([f"`{Path(link).name}`" for link in links])
|
||||||
return
|
return
|
||||||
@ -242,7 +237,6 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
|
|||||||
yield ("Getting the output folder")
|
yield ("Getting the output folder")
|
||||||
base_folder = shared.args.lora_dir if is_lora else shared.args.model_dir
|
base_folder = shared.args.lora_dir if is_lora else shared.args.model_dir
|
||||||
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, base_folder=base_folder)
|
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, base_folder=base_folder)
|
||||||
|
|
||||||
if check:
|
if check:
|
||||||
progress(0.5)
|
progress(0.5)
|
||||||
yield ("Checking previously downloaded files")
|
yield ("Checking previously downloaded files")
|
||||||
@ -250,7 +244,7 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
|
|||||||
progress(1.0)
|
progress(1.0)
|
||||||
else:
|
else:
|
||||||
yield (f"Downloading file{'s' if len(links) > 1 else ''} to `{output_folder}/`")
|
yield (f"Downloading file{'s' if len(links) > 1 else ''} to `{output_folder}/`")
|
||||||
downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=1, is_llamacpp=is_llamacpp)
|
downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=4, is_llamacpp=is_llamacpp)
|
||||||
yield ("Done!")
|
yield ("Done!")
|
||||||
except:
|
except:
|
||||||
progress(1.0)
|
progress(1.0)
|
||||||
|
Loading…
Reference in New Issue
Block a user