mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-28 18:48:04 +01:00
Downloader: start one session per file (#5520)
This commit is contained in:
parent
44018c2f69
commit
f465b7b486
@ -26,13 +26,16 @@ base = "https://huggingface.co"
|
|||||||
|
|
||||||
class ModelDownloader:
|
class ModelDownloader:
|
||||||
def __init__(self, max_retries=5):
|
def __init__(self, max_retries=5):
|
||||||
self.session = requests.Session()
|
self.max_retries = max_retries
|
||||||
if max_retries:
|
|
||||||
self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
|
def get_session(self):
|
||||||
self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
|
session = requests.Session()
|
||||||
|
if self.max_retries:
|
||||||
|
session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=self.max_retries))
|
||||||
|
session.mount('https://huggingface.co', HTTPAdapter(max_retries=self.max_retries))
|
||||||
|
|
||||||
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
|
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
|
||||||
self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
|
session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import get_token
|
from huggingface_hub import get_token
|
||||||
@ -41,7 +44,9 @@ class ModelDownloader:
|
|||||||
token = os.getenv("HF_TOKEN")
|
token = os.getenv("HF_TOKEN")
|
||||||
|
|
||||||
if token is not None:
|
if token is not None:
|
||||||
self.session.headers = {'authorization': f'Bearer {token}'}
|
session.headers = {'authorization': f'Bearer {token}'}
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
def sanitize_model_and_branch_names(self, model, branch):
|
def sanitize_model_and_branch_names(self, model, branch):
|
||||||
if model[-1] == '/':
|
if model[-1] == '/':
|
||||||
@ -65,6 +70,7 @@ class ModelDownloader:
|
|||||||
return model, branch
|
return model, branch
|
||||||
|
|
||||||
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
|
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
|
||||||
|
session = self.get_session()
|
||||||
page = f"/api/models/{model}/tree/{branch}"
|
page = f"/api/models/{model}/tree/{branch}"
|
||||||
cursor = b""
|
cursor = b""
|
||||||
|
|
||||||
@ -78,7 +84,7 @@ class ModelDownloader:
|
|||||||
is_lora = False
|
is_lora = False
|
||||||
while True:
|
while True:
|
||||||
url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
|
url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
|
||||||
r = self.session.get(url, timeout=10)
|
r = session.get(url, timeout=10)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
content = r.content
|
content = r.content
|
||||||
|
|
||||||
@ -171,6 +177,7 @@ class ModelDownloader:
|
|||||||
return output_folder
|
return output_folder
|
||||||
|
|
||||||
def get_single_file(self, url, output_folder, start_from_scratch=False):
|
def get_single_file(self, url, output_folder, start_from_scratch=False):
|
||||||
|
session = self.get_session()
|
||||||
filename = Path(url.rsplit('/', 1)[1])
|
filename = Path(url.rsplit('/', 1)[1])
|
||||||
output_path = output_folder / filename
|
output_path = output_folder / filename
|
||||||
headers = {}
|
headers = {}
|
||||||
@ -178,7 +185,7 @@ class ModelDownloader:
|
|||||||
if output_path.exists() and not start_from_scratch:
|
if output_path.exists() and not start_from_scratch:
|
||||||
|
|
||||||
# Check if the file has already been downloaded completely
|
# Check if the file has already been downloaded completely
|
||||||
r = self.session.get(url, stream=True, timeout=10)
|
r = session.get(url, stream=True, timeout=10)
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
if output_path.stat().st_size >= total_size:
|
if output_path.stat().st_size >= total_size:
|
||||||
return
|
return
|
||||||
@ -187,7 +194,7 @@ class ModelDownloader:
|
|||||||
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
||||||
mode = 'ab'
|
mode = 'ab'
|
||||||
|
|
||||||
with self.session.get(url, stream=True, headers=headers, timeout=10) as r:
|
with session.get(url, stream=True, headers=headers, timeout=10) as r:
|
||||||
r.raise_for_status() # Do not continue the download if the request was unsuccessful
|
r.raise_for_status() # Do not continue the download if the request was unsuccessful
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
block_size = 1024 * 1024 # 1MB
|
block_size = 1024 * 1024 # 1MB
|
||||||
|
Loading…
Reference in New Issue
Block a user