mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Update download-model.py (Allow single file download) (#3732)
This commit is contained in:
parent
dac5f4b912
commit
f63dd83631
@ -47,7 +47,7 @@ class ModelDownloader:
|
|||||||
|
|
||||||
return model, branch
|
return model, branch
|
||||||
|
|
||||||
def get_download_links_from_huggingface(self, model, branch, text_only=False):
|
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
|
||||||
base = "https://huggingface.co"
|
base = "https://huggingface.co"
|
||||||
page = f"/api/models/{model}/tree/{branch}"
|
page = f"/api/models/{model}/tree/{branch}"
|
||||||
cursor = b""
|
cursor = b""
|
||||||
@ -73,6 +73,9 @@ class ModelDownloader:
|
|||||||
|
|
||||||
for i in range(len(dict)):
|
for i in range(len(dict)):
|
||||||
fname = dict[i]['path']
|
fname = dict[i]['path']
|
||||||
|
if specific_file is not None and fname != specific_file:
|
||||||
|
continue
|
||||||
|
|
||||||
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
|
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
|
||||||
is_lora = True
|
is_lora = True
|
||||||
|
|
||||||
@ -126,12 +129,16 @@ class ModelDownloader:
|
|||||||
if classifications[i] == 'ggml':
|
if classifications[i] == 'ggml':
|
||||||
links.pop(i)
|
links.pop(i)
|
||||||
|
|
||||||
return links, sha256, is_lora
|
return links, sha256, is_lora, ((has_ggml or has_gguf) and specific_file is not None)
|
||||||
|
|
||||||
def get_output_folder(self, model, branch, is_lora, base_folder=None):
|
def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, base_folder=None):
|
||||||
if base_folder is None:
|
if base_folder is None:
|
||||||
base_folder = 'models' if not is_lora else 'loras'
|
base_folder = 'models' if not is_lora else 'loras'
|
||||||
|
|
||||||
|
# If the model is of type GGUF or GGML, save directly in the base_folder
|
||||||
|
if is_llamacpp:
|
||||||
|
return Path(base_folder)
|
||||||
|
|
||||||
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
||||||
if branch != 'main':
|
if branch != 'main':
|
||||||
output_folder += f'_{branch}'
|
output_folder += f'_{branch}'
|
||||||
@ -173,7 +180,7 @@ class ModelDownloader:
|
|||||||
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1):
|
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1):
|
||||||
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
||||||
|
|
||||||
def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1):
|
def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1, specific_file=None):
|
||||||
self.progress_bar = progress_bar
|
self.progress_bar = progress_bar
|
||||||
|
|
||||||
# Creating the folder and writing the metadata
|
# Creating the folder and writing the metadata
|
||||||
@ -189,8 +196,11 @@ class ModelDownloader:
|
|||||||
metadata += '\n'
|
metadata += '\n'
|
||||||
(output_folder / 'huggingface-metadata.txt').write_text(metadata)
|
(output_folder / 'huggingface-metadata.txt').write_text(metadata)
|
||||||
|
|
||||||
# Downloading the files
|
if specific_file:
|
||||||
print(f"Downloading the model to {output_folder}")
|
print(f"Downloading {specific_file} to {output_folder}")
|
||||||
|
else:
|
||||||
|
print(f"Downloading the model to {output_folder}")
|
||||||
|
|
||||||
self.start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
|
self.start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
|
||||||
|
|
||||||
def check_model_files(self, model, branch, links, sha256, output_folder):
|
def check_model_files(self, model, branch, links, sha256, output_folder):
|
||||||
@ -226,6 +236,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||||
|
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
|
||||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||||
@ -234,28 +245,29 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
branch = args.branch
|
branch = args.branch
|
||||||
model = args.MODEL
|
model = args.MODEL
|
||||||
|
specific_file = args.specific_file
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
|
print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
downloader = ModelDownloader(max_retries=args.max_retries)
|
downloader = ModelDownloader(max_retries=args.max_retries)
|
||||||
# Cleaning up the model/branch names
|
# Clean up the model/branch names
|
||||||
try:
|
try:
|
||||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||||
except ValueError as err_branch:
|
except ValueError as err_branch:
|
||||||
print(f"Error: {err_branch}")
|
print(f"Error: {err_branch}")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
# Getting the download links from Hugging Face
|
# Get the download links from Hugging Face
|
||||||
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only)
|
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file)
|
||||||
|
|
||||||
# Getting the output folder
|
# Get the output folder
|
||||||
output_folder = downloader.get_output_folder(model, branch, is_lora, base_folder=args.output)
|
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, base_folder=args.output)
|
||||||
|
|
||||||
if args.check:
|
if args.check:
|
||||||
# Check previously downloaded files
|
# Check previously downloaded files
|
||||||
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||||
else:
|
else:
|
||||||
# Download files
|
# Download files
|
||||||
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)
|
downloader.download_model_files(model, branch, links, sha256, output_folder, specific_file=specific_file, threads=args.threads)
|
||||||
|
Loading…
Reference in New Issue
Block a user