Add --exclude-pattern flag to download-model.py script (#6542)

Repository: https://github.com/oobabooga/text-generation-webui.git
Parent: 1f86722977
Commit: d3adcbf64b
download-model.py
@@ -72,7 +72,7 @@ class ModelDownloader:
 
         return model, branch
 
-    def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
+    def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None, exclude_pattern=None):
         session = self.session
         page = f"/api/models/{model}/tree/{branch}"
         cursor = b""
@@ -100,13 +100,17 @@ class ModelDownloader:
                 if specific_file not in [None, ''] and fname != specific_file:
                     continue
 
+                # Exclude files matching the exclude pattern
+                if exclude_pattern is not None and re.match(exclude_pattern, fname):
+                    continue
+
                 if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
                     is_lora = True
 
                 is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname)
                 is_safetensors = re.match(r".*\.safetensors", fname)
                 is_pt = re.match(r".*\.pt", fname)
-                is_gguf = re.match(r'.*\.gguf', fname)
+                is_gguf = re.match(r".*\.gguf", fname)
                 is_tiktoken = re.match(r".*\.tiktoken", fname)
                 is_tokenizer = re.match(r"(tokenizer|ice|spiece).*\.model", fname) or is_tiktoken
                 is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer
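For context, here is a minimal standalone sketch of the filter added in this hunk (the pattern and filenames below are illustrative, not from the commit). Because the check uses re.match, the pattern is anchored at the start of the filename:

import re

exclude_pattern = r".*\.gguf"  # illustrative pattern: skip all GGUF files
files = ["model-00001-of-00002.safetensors", "tokenizer.model", "model.Q4_K_M.gguf"]

# Same test as in the diff above: drop any fname the pattern matches
kept = [f for f in files if not (exclude_pattern is not None and re.match(exclude_pattern, f))]
print(kept)  # ['model-00001-of-00002.safetensors', 'tokenizer.model']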
@@ -140,7 +144,6 @@ class ModelDownloader:
 
         # If both pytorch and safetensors are available, download safetensors only
         # Also if GGUF and safetensors are available, download only safetensors
-        # (why do people do this?)
         if (has_pytorch or has_pt or has_gguf) and has_safetensors:
             has_gguf = False
             for i in range(len(classifications) - 1, -1, -1):
@@ -148,8 +151,6 @@ class ModelDownloader:
                     links.pop(i)
 
         # For GGUF, try to download only the Q4_K_M if no specific file is specified.
-        # If not present, exclude all GGUFs, as that's likely a repository with both
-        # GGUF and fp16 files.
         if has_gguf and specific_file is None:
             has_q4km = False
             for i in range(len(classifications) - 1, -1, -1):
@@ -312,6 +313,7 @@ if __name__ == '__main__':
     parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.')
     parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
     parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
+    parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.')
     parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
     parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).')
     parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
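As a quick standalone sketch (not part of the commit): argparse converts the dashed flag name to an underscored attribute, which is why the main block below can read args.exclude_pattern:

import argparse

parser = argparse.ArgumentParser()
# Same definition as the new flag in the hunk above
parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.')

# '--exclude-pattern' becomes the attribute 'exclude_pattern'
args = parser.parse_args(['--exclude-pattern', r'.*\.bin'])
print(args.exclude_pattern)  # .*\.bin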
@@ -322,6 +324,7 @@ if __name__ == '__main__':
     branch = args.branch
     model = args.MODEL
     specific_file = args.specific_file
+    exclude_pattern = args.exclude_pattern
 
     if model is None:
         print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
@@ -336,7 +339,9 @@ if __name__ == '__main__':
             sys.exit()
 
     # Get the download links from Hugging Face
-    links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file)
+    links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(
+        model, branch, text_only=args.text_only, specific_file=specific_file, exclude_pattern=exclude_pattern
+    )
 
     # Get the output folder
     if args.output:
@@ -349,4 +354,7 @@ if __name__ == '__main__':
         downloader.check_model_files(model, branch, links, sha256, output_folder)
     else:
         # Download files
-        downloader.download_model_files(model, branch, links, sha256, output_folder, specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp)
+        downloader.download_model_files(
+            model, branch, links, sha256, output_folder,
+            specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp
+        )
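With the flag wired through both call sites, a run such as python download-model.py facebook/opt-1.3b --exclude-pattern ".*\.bin" (model name illustrative, borrowed from the script's own error message) would download the repository while skipping every file the regex matches; because the script uses re.match, the pattern is anchored at the start of each filename.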