mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-09 20:19:06 +01:00
Add --exclude-pattern
flag to download-model.py script (#6542)
This commit is contained in:
parent
1f86722977
commit
d3adcbf64b
@ -72,7 +72,7 @@ class ModelDownloader:
|
||||
|
||||
return model, branch
|
||||
|
||||
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
|
||||
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None, exclude_pattern=None):
|
||||
session = self.session
|
||||
page = f"/api/models/{model}/tree/{branch}"
|
||||
cursor = b""
|
||||
@ -100,13 +100,17 @@ class ModelDownloader:
|
||||
if specific_file not in [None, ''] and fname != specific_file:
|
||||
continue
|
||||
|
||||
# Exclude files matching the exclude pattern
|
||||
if exclude_pattern is not None and re.match(exclude_pattern, fname):
|
||||
continue
|
||||
|
||||
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
|
||||
is_lora = True
|
||||
|
||||
is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname)
|
||||
is_safetensors = re.match(r".*\.safetensors", fname)
|
||||
is_pt = re.match(r".*\.pt", fname)
|
||||
is_gguf = re.match(r'.*\.gguf', fname)
|
||||
is_gguf = re.match(r".*\.gguf", fname)
|
||||
is_tiktoken = re.match(r".*\.tiktoken", fname)
|
||||
is_tokenizer = re.match(r"(tokenizer|ice|spiece).*\.model", fname) or is_tiktoken
|
||||
is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer
|
||||
@ -140,7 +144,6 @@ class ModelDownloader:
|
||||
|
||||
# If both pytorch and safetensors are available, download safetensors only
|
||||
# Also if GGUF and safetensors are available, download only safetensors
|
||||
# (why do people do this?)
|
||||
if (has_pytorch or has_pt or has_gguf) and has_safetensors:
|
||||
has_gguf = False
|
||||
for i in range(len(classifications) - 1, -1, -1):
|
||||
@ -148,8 +151,6 @@ class ModelDownloader:
|
||||
links.pop(i)
|
||||
|
||||
# For GGUF, try to download only the Q4_K_M if no specific file is specified.
|
||||
# If not present, exclude all GGUFs, as that's likely a repository with both
|
||||
# GGUF and fp16 files.
|
||||
if has_gguf and specific_file is None:
|
||||
has_q4km = False
|
||||
for i in range(len(classifications) - 1, -1, -1):
|
||||
@ -312,6 +313,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.')
|
||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
|
||||
parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.')
|
||||
parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
|
||||
parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).')
|
||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||
@ -322,6 +324,7 @@ if __name__ == '__main__':
|
||||
branch = args.branch
|
||||
model = args.MODEL
|
||||
specific_file = args.specific_file
|
||||
exclude_pattern = args.exclude_pattern
|
||||
|
||||
if model is None:
|
||||
print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
|
||||
@ -336,7 +339,9 @@ if __name__ == '__main__':
|
||||
sys.exit()
|
||||
|
||||
# Get the download links from Hugging Face
|
||||
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file)
|
||||
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(
|
||||
model, branch, text_only=args.text_only, specific_file=specific_file, exclude_pattern=exclude_pattern
|
||||
)
|
||||
|
||||
# Get the output folder
|
||||
if args.output:
|
||||
@ -349,4 +354,7 @@ if __name__ == '__main__':
|
||||
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||
else:
|
||||
# Download files
|
||||
downloader.download_model_files(model, branch, links, sha256, output_folder, specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp)
|
||||
downloader.download_model_files(
|
||||
model, branch, links, sha256, output_folder,
|
||||
specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user