mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 17:06:53 +01:00
Allow full model URL to be used for download (#3919)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
ed6b6411fb
commit
7c9664ed35
@ -22,6 +22,9 @@ from requests.adapters import HTTPAdapter
|
|||||||
from tqdm.contrib.concurrent import thread_map
|
from tqdm.contrib.concurrent import thread_map
|
||||||
|
|
||||||
|
|
||||||
|
base = "https://huggingface.co"
|
||||||
|
|
||||||
|
|
||||||
class ModelDownloader:
|
class ModelDownloader:
|
||||||
def __init__(self, max_retries=5):
|
def __init__(self, max_retries=5):
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
@ -37,6 +40,13 @@ class ModelDownloader:
|
|||||||
if model[-1] == '/':
|
if model[-1] == '/':
|
||||||
model = model[:-1]
|
model = model[:-1]
|
||||||
|
|
||||||
|
if model.startswith(base + '/'):
|
||||||
|
model = model[len(base) + 1:]
|
||||||
|
|
||||||
|
model_parts = model.split(":")
|
||||||
|
model = model_parts[0] if len(model_parts) > 0 else model
|
||||||
|
branch = model_parts[1] if len(model_parts) > 1 else branch
|
||||||
|
|
||||||
if branch is None:
|
if branch is None:
|
||||||
branch = "main"
|
branch = "main"
|
||||||
else:
|
else:
|
||||||
@ -48,7 +58,6 @@ class ModelDownloader:
|
|||||||
return model, branch
|
return model, branch
|
||||||
|
|
||||||
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
|
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
|
||||||
base = "https://huggingface.co"
|
|
||||||
page = f"/api/models/{model}/tree/{branch}"
|
page = f"/api/models/{model}/tree/{branch}"
|
||||||
cursor = b""
|
cursor = b""
|
||||||
|
|
||||||
|
@ -216,18 +216,14 @@ def load_lora_wrapper(selected_loras):
|
|||||||
yield ("Successfuly applied the LoRAs")
|
yield ("Successfuly applied the LoRAs")
|
||||||
|
|
||||||
|
|
||||||
def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False):
|
def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False):
|
||||||
try:
|
try:
|
||||||
downloader_module = importlib.import_module("download-model")
|
downloader_module = importlib.import_module("download-model")
|
||||||
downloader = downloader_module.ModelDownloader()
|
downloader = downloader_module.ModelDownloader()
|
||||||
repo_id_parts = repo_id.split(":")
|
|
||||||
model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id
|
|
||||||
branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"
|
|
||||||
check = False
|
|
||||||
|
|
||||||
progress(0.0)
|
progress(0.0)
|
||||||
yield ("Cleaning up the model/branch names")
|
yield ("Cleaning up the model/branch names")
|
||||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
model, branch = downloader.sanitize_model_and_branch_names(repo_id, None)
|
||||||
|
|
||||||
yield ("Getting the download links from Hugging Face")
|
yield ("Getting the download links from Hugging Face")
|
||||||
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file)
|
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file)
|
||||||
|
Loading…
Reference in New Issue
Block a user