Add --exclude-pattern flag to download-model.py script (#6542)

This commit is contained in:
Jack Cloudman 2025-01-08 14:30:21 -06:00 committed by GitHub
parent 1f86722977
commit d3adcbf64b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -72,7 +72,7 @@ class ModelDownloader:
return model, branch
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None, exclude_pattern=None):
session = self.session
page = f"/api/models/{model}/tree/{branch}"
cursor = b""
@ -100,13 +100,17 @@ class ModelDownloader:
if specific_file not in [None, ''] and fname != specific_file:
continue
# Exclude files matching the exclude pattern
if exclude_pattern is not None and re.match(exclude_pattern, fname):
continue
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
is_lora = True
is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname)
is_safetensors = re.match(r".*\.safetensors", fname)
is_pt = re.match(r".*\.pt", fname)
is_gguf = re.match(r'.*\.gguf', fname)
is_gguf = re.match(r".*\.gguf", fname)
is_tiktoken = re.match(r".*\.tiktoken", fname)
is_tokenizer = re.match(r"(tokenizer|ice|spiece).*\.model", fname) or is_tiktoken
is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer
@ -140,7 +144,6 @@ class ModelDownloader:
# If both pytorch and safetensors are available, download safetensors only
# Also if GGUF and safetensors are available, download only safetensors
# (why do people do this?)
if (has_pytorch or has_pt or has_gguf) and has_safetensors:
has_gguf = False
for i in range(len(classifications) - 1, -1, -1):
@ -148,8 +151,6 @@ class ModelDownloader:
links.pop(i)
# For GGUF, try to download only the Q4_K_M if no specific file is specified.
# If not present, exclude all GGUFs, as that's likely a repository with both
# GGUF and fp16 files.
if has_gguf and specific_file is None:
has_q4km = False
for i in range(len(classifications) - 1, -1, -1):
@ -312,6 +313,7 @@ if __name__ == '__main__':
parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.')
parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
@ -322,6 +324,7 @@ if __name__ == '__main__':
branch = args.branch
model = args.MODEL
specific_file = args.specific_file
exclude_pattern = args.exclude_pattern
if model is None:
print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
@ -336,7 +339,9 @@ if __name__ == '__main__':
sys.exit()
# Get the download links from Hugging Face
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file)
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(
model, branch, text_only=args.text_only, specific_file=specific_file, exclude_pattern=exclude_pattern
)
# Get the output folder
if args.output:
@ -349,4 +354,7 @@ if __name__ == '__main__':
downloader.check_model_files(model, branch, links, sha256, output_folder)
else:
# Download files
downloader.download_model_files(model, branch, links, sha256, output_folder, specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp)
downloader.download_model_files(
model, branch, links, sha256, output_folder,
specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp
)