From 862b45b1c798cd3a16d2a95939d87cd4068a53ad Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 25 Sep 2023 19:48:30 -0700 Subject: [PATCH] Extension install improvements --- modules/github.py | 22 +++++++++++++++------- modules/utils.py | 6 ++++-- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/modules/github.py b/modules/github.py index 454e9d23..d68ca847 100644 --- a/modules/github.py +++ b/modules/github.py @@ -1,19 +1,25 @@ -import os import subprocess +from pathlib import Path + +from modules.logging_colors import logger + +new_extensions = set() def clone_or_pull_repository(github_url): - repository_folder = "extensions" - repo_name = github_url.split("/")[-1].split(".")[0] + global new_extensions + + repository_folder = Path("extensions") + repo_name = github_url.rstrip("/").split("/")[-1].split(".")[0] # Check if the repository folder exists - if not os.path.exists(repository_folder): - os.makedirs(repository_folder) + if not repository_folder.exists(): + repository_folder.mkdir(parents=True) - repo_path = os.path.join(repository_folder, repo_name) + repo_path = repository_folder / repo_name # Check if the repository is already cloned - if os.path.exists(repo_path): + if repo_path.exists(): yield f"Updating {github_url}..." # Perform a 'git pull' to update the repository try: @@ -27,6 +33,8 @@ def clone_or_pull_repository(github_url): try: yield f"Cloning {github_url}..." clone_output = subprocess.check_output(["git", "clone", github_url, repo_path], stderr=subprocess.STDOUT) + new_extensions.add(repo_name) + logger.info(f"The extension {repo_name} has been downloaded. Please close the the web UI and launch it again to be able to load it.") yield "Done." return clone_output.decode() except subprocess.CalledProcessError as e: diff --git a/modules/utils.py b/modules/utils.py index e6449052..e4eef224 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -3,7 +3,7 @@ import re from datetime import datetime from pathlib import Path -from modules import shared +from modules import github, shared from modules.logging_colors import logger @@ -107,7 +107,9 @@ def get_available_instruction_templates(): def get_available_extensions(): - return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys) + extensions = sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys) + extensions = [v for v in extensions if v not in github.new_extensions] + return extensions def get_available_loras():