Extension install improvements

This commit is contained in:
oobabooga 2023-09-25 19:48:30 -07:00
parent 44438c60e5
commit 862b45b1c7
2 changed files with 19 additions and 9 deletions

View File

@ -1,19 +1,25 @@
import os
import subprocess import subprocess
from pathlib import Path
from modules.logging_colors import logger
new_extensions = set()
def clone_or_pull_repository(github_url): def clone_or_pull_repository(github_url):
repository_folder = "extensions" global new_extensions
repo_name = github_url.split("/")[-1].split(".")[0]
repository_folder = Path("extensions")
repo_name = github_url.rstrip("/").split("/")[-1].split(".")[0]
# Check if the repository folder exists # Check if the repository folder exists
if not os.path.exists(repository_folder): if not repository_folder.exists():
os.makedirs(repository_folder) repository_folder.mkdir(parents=True)
repo_path = os.path.join(repository_folder, repo_name) repo_path = repository_folder / repo_name
# Check if the repository is already cloned # Check if the repository is already cloned
if os.path.exists(repo_path): if repo_path.exists():
yield f"Updating {github_url}..." yield f"Updating {github_url}..."
# Perform a 'git pull' to update the repository # Perform a 'git pull' to update the repository
try: try:
@ -27,6 +33,8 @@ def clone_or_pull_repository(github_url):
try: try:
yield f"Cloning {github_url}..." yield f"Cloning {github_url}..."
clone_output = subprocess.check_output(["git", "clone", github_url, repo_path], stderr=subprocess.STDOUT) clone_output = subprocess.check_output(["git", "clone", github_url, repo_path], stderr=subprocess.STDOUT)
new_extensions.add(repo_name)
logger.info(f"The extension {repo_name} has been downloaded. Please close the the web UI and launch it again to be able to load it.")
yield "Done." yield "Done."
return clone_output.decode() return clone_output.decode()
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:

View File

@ -3,7 +3,7 @@ import re
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from modules import shared from modules import github, shared
from modules.logging_colors import logger from modules.logging_colors import logger
@ -107,7 +107,9 @@ def get_available_instruction_templates():
def get_available_extensions(): def get_available_extensions():
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys) extensions = sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
extensions = [v for v in extensions if v not in github.new_extensions]
return extensions
def get_available_loras(): def get_available_loras():