If both .pt and .safetensors are present, download only safetensors

This commit is contained in:
oobabooga 2023-03-28 13:08:38 -03:00
parent 8579fe51dd
commit 91aa5b460e

View File

@ -100,6 +100,7 @@ def get_download_links_from_huggingface(model, branch):
links = [] links = []
classifications = [] classifications = []
has_pytorch = False has_pytorch = False
has_pt = False
has_safetensors = False has_safetensors = False
is_lora = False is_lora = False
while True: while True:
@ -115,7 +116,7 @@ def get_download_links_from_huggingface(model, branch):
is_lora = True is_lora = True
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname) is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname) is_safetensors = re.match(".*\.safetensors", fname)
is_pt = re.match(".*\.pt", fname) is_pt = re.match(".*\.pt", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname) is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
@ -134,6 +135,7 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True has_pytorch = True
classifications.append('pytorch') classifications.append('pytorch')
elif is_pt: elif is_pt:
has_pt = True
classifications.append('pt') classifications.append('pt')
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
@ -141,9 +143,9 @@ def get_download_links_from_huggingface(model, branch):
cursor = cursor.replace(b'=', b'%3D') cursor = cursor.replace(b'=', b'%3D')
# If both pytorch and safetensors are available, download safetensors only # If both pytorch and safetensors are available, download safetensors only
if has_pytorch and has_safetensors: if (has_pytorch or has_pt) and has_safetensors:
for i in range(len(classifications)-1, -1, -1): for i in range(len(classifications)-1, -1, -1):
if classifications[i] == 'pytorch': if classifications[i] in ['pytorch', 'pt']:
links.pop(i) links.pop(i)
return links, is_lora return links, is_lora