mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-26 22:30:44 +01:00
Add support for resuming downloads
This commit adds the ability to resume interrupted downloads by adding a new function to the downloader module. The function uses the HTTP Range header to fetch only the remaining part of a file that wasn't downloaded yet.
This commit is contained in:
parent
f0fdab08d3
commit
e17af59261
@ -27,8 +27,23 @@ parser.add_argument('--output', type=str, default=None, help='The folder where t
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def get_file(url, output_folder):
|
def get_file(url, output_folder):
|
||||||
r = requests.get(url, stream=True)
|
filename = Path(url.rsplit('/', 1)[1])
|
||||||
with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
|
output_path = output_folder / filename
|
||||||
|
if output_path.exists():
|
||||||
|
# Check if the file has already been downloaded completely
|
||||||
|
r = requests.head(url)
|
||||||
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
|
if output_path.stat().st_size == total_size:
|
||||||
|
return
|
||||||
|
# Otherwise, resume the download from where it left off
|
||||||
|
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
||||||
|
mode = 'ab'
|
||||||
|
else:
|
||||||
|
headers = {}
|
||||||
|
mode = 'wb'
|
||||||
|
|
||||||
|
r = requests.get(url, stream=True, headers=headers)
|
||||||
|
with open(output_path, mode) as f:
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
block_size = 1024
|
block_size = 1024
|
||||||
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
||||||
@ -149,7 +164,7 @@ def get_download_links_from_huggingface(model, branch):
|
|||||||
return links, sha256, is_lora
|
return links, sha256, is_lora
|
||||||
|
|
||||||
def download_files(file_list, output_folder, num_threads=8):
|
def download_files(file_list, output_folder, num_threads=8):
|
||||||
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads)
|
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
model = args.MODEL
|
model = args.MODEL
|
||||||
|
Loading…
Reference in New Issue
Block a user