mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-25 05:48:55 +01:00
Add support for resuming downloads (#654 from nikita-skakun/support-partial-downloads)
This commit is contained in:
commit
23116b88ef
@ -9,6 +9,7 @@ python download-model.py facebook/opt-1.3b
|
|||||||
import argparse
|
import argparse
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
@ -24,11 +25,28 @@ parser.add_argument('--branch', type=str, default='main', help='Name of the Git
|
|||||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||||
|
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||||
|
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def get_file(url, output_folder):
|
def get_file(url, output_folder):
|
||||||
r = requests.get(url, stream=True)
|
filename = Path(url.rsplit('/', 1)[1])
|
||||||
with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
|
output_path = output_folder / filename
|
||||||
|
if output_path.exists() and not args.clean:
|
||||||
|
# Check if the file has already been downloaded completely
|
||||||
|
r = requests.get(url, stream=True)
|
||||||
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
|
if output_path.stat().st_size >= total_size:
|
||||||
|
return
|
||||||
|
# Otherwise, resume the download from where it left off
|
||||||
|
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
||||||
|
mode = 'ab'
|
||||||
|
else:
|
||||||
|
headers = {}
|
||||||
|
mode = 'wb'
|
||||||
|
|
||||||
|
r = requests.get(url, stream=True, headers=headers)
|
||||||
|
with open(output_path, mode) as f:
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
block_size = 1024
|
block_size = 1024
|
||||||
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
||||||
@ -154,7 +172,7 @@ def get_download_links_from_huggingface(model, branch):
|
|||||||
return links, sha256, is_lora
|
return links, sha256, is_lora
|
||||||
|
|
||||||
def download_files(file_list, output_folder, num_threads=8):
|
def download_files(file_list, output_folder, num_threads=8):
|
||||||
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads)
|
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model = args.MODEL
|
model = args.MODEL
|
||||||
@ -184,22 +202,48 @@ if __name__ == '__main__':
|
|||||||
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
||||||
if branch != 'main':
|
if branch != 'main':
|
||||||
output_folder += f'_{branch}'
|
output_folder += f'_{branch}'
|
||||||
|
|
||||||
# Creating the folder and writing the metadata
|
|
||||||
output_folder = Path(base_folder) / output_folder
|
output_folder = Path(base_folder) / output_folder
|
||||||
if not output_folder.exists():
|
|
||||||
output_folder.mkdir()
|
|
||||||
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
|
|
||||||
f.write(f'url: https://huggingface.co/{model}\n')
|
|
||||||
f.write(f'branch: {branch}\n')
|
|
||||||
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
|
|
||||||
sha256_str = ''
|
|
||||||
for i in range(len(sha256)):
|
|
||||||
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
|
|
||||||
if sha256_str != '':
|
|
||||||
f.write(f'sha256sum:\n{sha256_str}')
|
|
||||||
|
|
||||||
# Downloading the files
|
if args.check:
|
||||||
print(f"Downloading the model to {output_folder}")
|
# Validate the checksums
|
||||||
download_files(links, output_folder, args.threads)
|
validated = True
|
||||||
print()
|
for i in range(len(sha256)):
|
||||||
|
fpath = (output_folder / sha256[i][0])
|
||||||
|
|
||||||
|
if not fpath.exists():
|
||||||
|
print(f"The following file is missing: {fpath}")
|
||||||
|
validated = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
with open(output_folder / sha256[i][0], "rb") as f:
|
||||||
|
bytes = f.read()
|
||||||
|
file_hash = hashlib.sha256(bytes).hexdigest()
|
||||||
|
if file_hash != sha256[i][1]:
|
||||||
|
print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
|
||||||
|
validated = False
|
||||||
|
else:
|
||||||
|
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
|
||||||
|
|
||||||
|
if validated:
|
||||||
|
print('[+] Validated checksums of all model files!')
|
||||||
|
else:
|
||||||
|
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
# Creating the folder and writing the metadata
|
||||||
|
if not output_folder.exists():
|
||||||
|
output_folder.mkdir()
|
||||||
|
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
|
||||||
|
f.write(f'url: https://huggingface.co/{model}\n')
|
||||||
|
f.write(f'branch: {branch}\n')
|
||||||
|
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
|
||||||
|
sha256_str = ''
|
||||||
|
for i in range(len(sha256)):
|
||||||
|
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
|
||||||
|
if sha256_str != '':
|
||||||
|
f.write(f'sha256sum:\n{sha256_str}')
|
||||||
|
|
||||||
|
# Downloading the files
|
||||||
|
print(f"Downloading the model to {output_folder}")
|
||||||
|
download_files(links, output_folder, args.threads)
|
Loading…
Reference in New Issue
Block a user