From 7c9664ed35e898e51cfbc45d223a6f78363a3932 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 16 Sep 2023 08:06:13 -0500 Subject: [PATCH] Allow full model URL to be used for download (#3919) --------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com> --- download-model.py | 11 ++++++++++- modules/ui_model_menu.py | 8 ++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/download-model.py b/download-model.py index ba4d3bc7..d9b21d3a 100644 --- a/download-model.py +++ b/download-model.py @@ -22,6 +22,9 @@ from requests.adapters import HTTPAdapter from tqdm.contrib.concurrent import thread_map +base = "https://huggingface.co" + + class ModelDownloader: def __init__(self, max_retries=5): self.session = requests.Session() @@ -37,6 +40,13 @@ class ModelDownloader: if model[-1] == '/': model = model[:-1] + if model.startswith(base + '/'): + model = model[len(base) + 1:] + + model_parts = model.split(":") + model = model_parts[0] if len(model_parts) > 0 else model + branch = model_parts[1] if len(model_parts) > 1 else branch + if branch is None: branch = "main" else: @@ -48,7 +58,6 @@ class ModelDownloader: return model, branch def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None): - base = "https://huggingface.co" page = f"/api/models/{model}/tree/{branch}" cursor = b"" diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 0a063c39..dcb8778f 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -216,18 +216,14 @@ def load_lora_wrapper(selected_loras): yield ("Successfuly applied the LoRAs") -def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False): +def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False): try: downloader_module = importlib.import_module("download-model") downloader = downloader_module.ModelDownloader() - repo_id_parts = repo_id.split(":") - model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id - branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main" - check = False progress(0.0) yield ("Cleaning up the model/branch names") - model, branch = downloader.sanitize_model_and_branch_names(model, branch) + model, branch = downloader.sanitize_model_and_branch_names(repo_id, None) yield ("Getting the download links from Hugging Face") links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file)