From 85f45cafa1546f8b17956bd5c08eda9e977f7bb6 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 28 Sep 2023 13:54:36 -0700 Subject: [PATCH] Fix extensions install --- one_click.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/one_click.py b/one_click.py index 0aae427b..3a23246f 100644 --- a/one_click.py +++ b/one_click.py @@ -190,19 +190,18 @@ def update_requirements(initial_installation=False): # Extensions requirements are installed only during the initial install by default. # That can be changed with the INSTALL_EXTENSIONS environment variable. install_extensions = os.environ.get("INSTALL_EXTENSIONS", "false").lower() in ("yes", "y", "true", "1", "t", "on") - if initial_installation or install_extensions: - if not install_extensions: - print_big_message("Will not install extensions due to INSTALL_EXTENSIONS environment variable.") - else: - print("Installing extensions requirements.") - extensions = next(os.walk("extensions"))[1] - for extension in extensions: - if extension in ['superbooga', 'superboogav2']: # No wheels available for requirements - continue + if initial_installation and not install_extensions: + print_big_message("Will not install extensions due to INSTALL_EXTENSIONS environment variable.") + elif initial_installation or install_extensions: + print("Installing extensions requirements.") + extensions = next(os.walk("extensions"))[1] + for extension in extensions: + if extension in ['superbooga', 'superboogav2']: # No wheels available for requirements + continue - extension_req_path = os.path.join("extensions", extension, "requirements.txt") - if os.path.exists(extension_req_path): - run_cmd("python -m pip install -r " + extension_req_path + " --upgrade", assert_success=True, environment=True) + extension_req_path = os.path.join("extensions", extension, "requirements.txt") + if os.path.exists(extension_req_path): + run_cmd("python -m pip install -r " + extension_req_path + " --upgrade", assert_success=True, environment=True) # Detect the PyTorch version torver = torch_version()