2023-09-22 00:35:53 +02:00
import argparse
import glob
2023-10-07 05:23:49 +02:00
import hashlib
2023-09-22 00:35:53 +02:00
import os
2023-09-24 14:58:29 +02:00
import platform
2023-09-26 15:56:57 +02:00
import re
2023-12-05 06:16:16 +01:00
import signal
2023-09-22 05:12:16 +02:00
import site
2023-09-23 15:48:09 +02:00
import subprocess
2023-09-22 00:35:53 +02:00
import sys
2023-09-22 17:02:21 +02:00
# Remove the '# ' from the following lines as needed for your AMD GPU on Linux
# os.environ["ROCM_PATH"] = '/opt/rocm'
# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0'
# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030'
2024-03-03 23:40:32 +01:00
# Pinned PyTorch stack installed / upgraded by this script
TORCH_VERSION = "2.4.1"
TORCHVISION_VERSION = "0.19.1"
TORCHAUDIO_VERSION = "2.4.1"

# Environment: everything is resolved relative to the current working directory
script_dir = os.getcwd()
conda_env_path = os.path.join(script_dir, "installer_files", "env")

# Command-line flags: CMD_FLAGS.txt holds extra server flags, one or more per line.
# Comment lines ('#') are ignored, trailing shell line-continuation backslashes are
# removed, empty lines are skipped, and the rest is joined with single spaces.
cmd_flags_path = os.path.join(script_dir, "CMD_FLAGS.txt")
CMD_FLAGS = ''
if os.path.exists(cmd_flags_path):
    with open(cmd_flags_path, 'r') as f:
        CMD_FLAGS = ' '.join(
            flag_line
            for flag_line in (raw.strip().rstrip('\\').strip() for raw in f if not raw.strip().startswith('#'))
            if flag_line
        )

# Flags forwarded to server.py: our own argv (minus the installer-only
# --update-wizard switch) followed by the contents of CMD_FLAGS.txt.
flags = ' '.join(arg for arg in sys.argv[1:] if arg != '--update-wizard') + ' ' + CMD_FLAGS
2023-09-22 00:35:53 +02:00
2023-09-28 22:56:15 +02:00
2023-12-05 06:16:16 +01:00
def signal_handler(sig, frame):
    """SIGINT (Ctrl+C) handler: leave the script with exit status 0.

    The signal number and frame supplied by the signal module are ignored.
    """
    raise SystemExit(0)


# Install the handler at import time so Ctrl+C anywhere in the installer exits with status 0.
signal.signal(signal.SIGINT, signal_handler)
2023-09-22 17:02:21 +02:00
def is_linux():
    """True when sys.platform names a Linux build (it begins with 'linux')."""
    return sys.platform[:5] == "linux"
def is_windows():
    """True when sys.platform begins with 'win' (e.g. 'win32')."""
    return sys.platform[:3] == "win"
def is_macos():
    """True when sys.platform begins with 'darwin' (macOS)."""
    return sys.platform[:6] == "darwin"
2023-09-24 14:58:29 +02:00
def is_x86_64():
    """True when platform.machine() reports exactly 'x86_64'.

    Used to tell Intel Macs from Apple Silicon when picking a requirements file.
    """
    arch = platform.machine()
    return arch == "x86_64"
2024-04-30 14:11:31 +02:00
def cpu_has_avx2():
    """Return True if py-cpuinfo reports the 'avx2' CPU flag.

    If the check itself fails (py-cpuinfo not importable, no 'flags' entry,
    any other ordinary error), AVX2 is assumed to be available and True is
    returned, preserving the historical best-effort behaviour.
    """
    try:
        import cpuinfo
        info = cpuinfo.get_cpu_info()
        return 'avx2' in info['flags']
    # Catch Exception, not everything: a bare except also swallowed the
    # SystemExit raised by the SIGINT handler, so Ctrl+C was ignored here.
    except Exception:
        return True
def cpu_has_amx():
    """Return True if py-cpuinfo reports Intel AMX support.

    Besides an exact 'amx' flag, any flag starting with 'amx_' counts: the
    Linux kernel exposes AMX as amx_bf16 / amx_tile / amx_int8 rather than a
    bare 'amx' (NOTE(review): confirm which naming py-cpuinfo reports on
    every platform). If the check itself fails (py-cpuinfo missing, no
    'flags' entry), AMX is assumed to be available and True is returned.
    """
    try:
        import cpuinfo
        info = cpuinfo.get_cpu_info()
        return any(flag == 'amx' or flag.startswith('amx_') for flag in info['flags'])
    # Catch Exception, not everything, so SystemExit from the SIGINT handler
    # (and KeyboardInterrupt) still propagate.
    except Exception:
        return True
2023-09-24 14:58:29 +02:00
def torch_version():
    """Return the installed PyTorch version string (e.g. '2.1.0+cu118').

    When a site-packages directory inside the installer env appears in
    site.getsitepackages(), the version is read from its torch/version.py
    without importing torch. Otherwise, or if that file has no __version__
    line, the version is taken from ``import torch``.
    """
    site_packages_path = None
    for sitedir in site.getsitepackages():
        if "site-packages" in sitedir and conda_env_path in sitedir:
            site_packages_path = sitedir
            break

    if site_packages_path:
        version_file = os.path.join(site_packages_path, 'torch', 'version.py')
        # Use a context manager so the handle is closed (it used to leak).
        with open(version_file) as f:
            for line in f:
                if line.startswith('__version__'):
                    # "__version__ = '2.4.1+cu121'" -> "2.4.1+cu121"; split on '=' so
                    # spacing/annotation differences don't break parsing.
                    return line.partition('=')[2].strip().strip("'\"")

    from torch import __version__ as torver
    return torver
2024-03-03 23:40:32 +01:00
def update_pytorch():
    """Upgrade torch/torchvision/torchaudio to the pinned versions.

    The wheel index is chosen from the local-version tag of the torch build
    that is currently installed: +cu118, other +cu*, +rocm, +cpu, or the
    Intel +cxx11 build (which gets its own fixed Intel XPU command).
    """
    print_big_message("Checking for PyTorch updates")
    installed = torch_version()

    # Checked in order; '+cu118' must precede the generic '+cu' tag.
    index_by_tag = (
        ('+cu118', " --index-url https://download.pytorch.org/whl/cu118"),  # 2.1.0+cu118
        ('+cu', " --index-url https://download.pytorch.org/whl/cu121"),
        ('+rocm', " --index-url https://download.pytorch.org/whl/rocm6.1"),  # 2.0.1+rocm5.4.2
        ('+cpu', " --index-url https://download.pytorch.org/whl/cpu"),  # 2.0.1+cpu
    )

    command = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}"
    for tag, index_suffix in index_by_tag:
        if tag in installed:
            command += index_suffix
            break
    else:
        if '+cxx11' in installed:  # 2.0.1a0+cxx11.abi
            if is_linux():
                command = "python -m pip install --upgrade torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
            else:
                command = "python -m pip install --upgrade torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"

    run_cmd(command, assert_success=True, environment=True)
2023-09-22 17:02:21 +02:00
def is_installed():
    """Report whether the web UI environment looks installed.

    If a site-packages directory inside the installer env is listed by
    site.getsitepackages(), the answer is whether torch/__init__.py exists
    there; otherwise it is whether the env directory itself exists.
    """
    matching_dirs = (
        sitedir for sitedir in site.getsitepackages()
        if "site-packages" in sitedir and conda_env_path in sitedir
    )
    site_packages_path = next(matching_dirs, None)

    if not site_packages_path:
        return os.path.isdir(conda_env_path)
    return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py'))
def check_env():
    """Exit with status 1 unless conda is usable and we are not in the 'base' env."""
    # If we have access to conda, we are probably in an environment
    conda_exist = run_cmd("conda", environment=True, capture_output=True).returncode == 0
    if not conda_exist:
        print("Conda is not installed. Exiting...")
        sys.exit(1)

    # Ensure this is a new environment and not the base environment.
    # Use .get(): indexing os.environ crashed with KeyError when CONDA_DEFAULT_ENV was unset.
    if os.environ.get("CONDA_DEFAULT_ENV") == "base":
        print("Create an environment for this project and activate it. Exiting...")
        sys.exit(1)
2023-09-22 17:02:21 +02:00
def clear_cache():
    """Run `conda clean -a -y` and `pip cache purge` inside the installer env."""
    for cleanup_cmd in ("conda clean -a -y", "python -m pip cache purge"):
        run_cmd(cleanup_cmd, environment=True)
2023-09-22 00:35:53 +02:00
def print_big_message(message):
    """Print *message* framed by asterisk rulers, one '* '-prefixed row per line.

    Surrounding whitespace of the whole message is stripped first; two blank
    lines precede the top ruler and follow the bottom one.
    """
    ruler = "*******************************************************************"
    print("\n\n" + ruler)
    for row in message.strip().split('\n'):
        print("*", row)
    print(ruler + "\n\n")
2023-10-07 05:23:49 +02:00
def calculate_file_hash(file_path):
    """Return the hex SHA-256 of *file_path* (resolved against script_dir).

    Returns '' when the path is not an existing regular file.
    """
    full_path = os.path.join(script_dir, file_path)
    if not os.path.isfile(full_path):
        return ''
    with open(full_path, 'rb') as handle:
        digest = hashlib.sha256(handle.read())
    return digest.hexdigest()
2023-09-22 00:35:53 +02:00
def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None):
    """Run the shell command string *cmd* and return its CompletedProcess.

    Parameters:
        cmd: command line passed to the shell (shell=True).
        assert_success: if true, a non-zero exit code prints an error and
            exits the script with status 1.
        environment: if true, first activate the installer's conda env
            (installer_files/env) using the bundled conda.
        capture_output, env: forwarded unchanged to subprocess.run.
    """
    on_windows = is_windows()
    conda_root = os.path.join(script_dir, "installer_files", "conda")

    # Use the conda environment
    if environment:
        if on_windows:
            activator = os.path.join(conda_root, "condabin", "conda.bat")
            cmd = f'"{activator}" activate "{conda_env_path}" >nul && {cmd}'
        else:
            activator = os.path.join(conda_root, "etc", "profile.d", "conda.sh")
            cmd = f'. "{activator}" && conda activate "{conda_env_path}" && {cmd}'

    # Default shell on Windows; bash elsewhere so the '.'/'&&' syntax above works.
    shell_program = None if on_windows else 'bash'
    result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env, executable=shell_program)

    if assert_success and result.returncode != 0:
        print(f"Command '{cmd}' failed with exit status code '{str(result.returncode)}'.\n\nExiting now.\nTry running the start/update script again.")
        sys.exit(1)

    return result
2024-03-06 16:36:23 +01:00
def generate_alphabetic_sequence(index):
    """Map a 0-based index to spreadsheet-style column letters.

    0 -> 'A', 25 -> 'Z', 26 -> 'AA', 701 -> 'ZZ', 702 -> 'AAA'.
    Negative indices yield ''.
    """
    letters = []
    remaining = index
    while remaining >= 0:
        remaining, offset = divmod(remaining, 26)
        letters.append(chr(ord('A') + offset))
        remaining -= 1  # bijective base-26: there is no "zero" digit
    return ''.join(reversed(letters))
2024-03-04 19:52:24 +01:00
def get_user_choice(question, options_dict):
    """Print *question* and its lettered options, then read until a valid key is typed.

    Input is upper-cased before validation, so 'a' selects option 'A'.
    Returns the chosen (upper-case) key.
    """
    print()
    print(question)
    print()
    for letter, label in options_dict.items():
        print(f"{letter}) {label}")
    print()

    while True:
        answer = input("Input> ").upper()
        if answer in options_dict:
            return answer
        print("Invalid choice. Please try again.")
2023-09-22 21:08:05 +02:00
def _select_gpu_choice():
    # Return the GPU menu letter, from the GPU_CHOICE env var or an interactive prompt.
    if "GPU_CHOICE" in os.environ:
        choice = os.environ["GPU_CHOICE"].upper()
        print_big_message(f"Selected GPU choice \"{choice}\" based on the GPU_CHOICE environment variable.")

        # Warn about changed meanings and handle old NVIDIA choice
        if choice == "B":
            print_big_message("Warning: GPU_CHOICE='B' now means 'NVIDIA (CUDA 11.8)' in the new version.")
        elif choice == "C":
            print_big_message("Warning: GPU_CHOICE='C' now means 'AMD' in the new version.")
        elif choice == "D":
            print_big_message("Warning: GPU_CHOICE='D' now means 'Apple M Series' in the new version.")
        elif choice == "A" and "USE_CUDA118" in os.environ:
            choice = "B" if os.environ.get("USE_CUDA118", "").lower() in ("yes", "y", "true", "1", "t", "on") else "A"
        return choice

    return get_user_choice(
        "What is your GPU?",
        {
            'A': 'NVIDIA - CUDA 12.1 (recommended)',
            'B': 'NVIDIA - CUDA 11.8 (legacy GPUs)',
            'C': 'AMD - Linux/macOS only, requires ROCm 6.1',
            'D': 'Apple M Series',
            'E': 'Intel Arc (beta)',
            'N': 'CPU mode'
        },
    )


def _pytorch_install_command(selected_gpu, use_cuda118):
    # Build the pip command that installs the PyTorch build for the selected GPU vendor.
    install_pytorch = f"python -m pip install torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}"
    if selected_gpu == "NVIDIA":
        # use_cuda118 is a bool; the old `== 'Y'` test was always False, so
        # CUDA 11.8 users silently received cu121 wheels.
        if use_cuda118:
            install_pytorch += " --index-url https://download.pytorch.org/whl/cu118"
        else:
            install_pytorch += " --index-url https://download.pytorch.org/whl/cu121"
    elif selected_gpu == "AMD":
        install_pytorch += " --index-url https://download.pytorch.org/whl/rocm6.1"
    elif selected_gpu in ["APPLE", "NONE"]:
        install_pytorch += " --index-url https://download.pytorch.org/whl/cpu"
    elif selected_gpu == "INTEL":
        if is_linux():
            install_pytorch = "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
        else:
            install_pytorch = "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
    return install_pytorch


def install_webui():
    """First-time installation.

    Steps, in order:
    1. Pick the GPU backend.
    2. In CPU mode, record --cpu in CMD_FLAGS.txt.
    3. Install ninja, git and the matching PyTorch build (plus Intel oneAPI
       runtimes for Intel GPUs).
    4. Install the web UI requirements.

    Exits with status 1 on an unknown GPU choice or on AMD + Windows.
    """
    # Ask the user for the GPU vendor
    choice = _select_gpu_choice()

    # Convert choices to GPU names for compatibility
    gpu_choice_to_name = {
        "A": "NVIDIA",
        "B": "NVIDIA",
        "C": "AMD",
        "D": "APPLE",
        "E": "INTEL",
        "N": "NONE"
    }
    # A bad GPU_CHOICE value used to surface as a bare KeyError traceback.
    if choice not in gpu_choice_to_name:
        print_big_message(f"Invalid GPU choice \"{choice}\". Valid options are: {', '.join(gpu_choice_to_name)}.")
        sys.exit(1)

    selected_gpu = gpu_choice_to_name[choice]
    use_cuda118 = (choice == "B")  # CUDA version is now determined by menu choice

    # Write a flag to CMD_FLAGS.txt for CPU mode
    if selected_gpu == "NONE":
        # 'a+' creates the file if missing ('r+' raised FileNotFoundError);
        # reads need an explicit seek(0), writes always append.
        with open(cmd_flags_path, 'a+') as cmd_flags_file:
            cmd_flags_file.seek(0)
            if "--cpu" not in cmd_flags_file.read():
                print_big_message("Adding the --cpu flag to CMD_FLAGS.txt.")
                cmd_flags_file.write("\n--cpu\n")

    # Handle CUDA version display
    elif any((is_windows(), is_linux())) and selected_gpu == "NVIDIA":
        if use_cuda118:
            print("CUDA: 11.8")
        else:
            print("CUDA: 12.1")

    # No PyTorch for AMD on Windows (?)
    elif is_windows() and selected_gpu == "AMD":
        print("PyTorch setup on Windows is not implemented yet. Exiting...")
        sys.exit(1)

    # Find the Pytorch installation command
    install_pytorch = _pytorch_install_command(selected_gpu, use_cuda118)

    # Install Git and then Pytorch
    print_big_message("Installing PyTorch.")
    run_cmd(f"conda install -y -k ninja git && {install_pytorch} && python -m pip install py-cpuinfo==9.0.0", assert_success=True, environment=True)

    if selected_gpu == "INTEL":
        # Install oneAPI dependencies via conda. environment=True targets the
        # installer env explicitly, like every other conda call in this script.
        print_big_message("Installing Intel oneAPI runtime libraries.")
        run_cmd("conda install -y -c https://software.repos.intel.com/python/conda/ -c conda-forge dpcpp-cpp-rt=2024.0 mkl-dpcpp=2024.0", environment=True)
        # Install libuv required by Intel-patched torch
        run_cmd("conda install -y libuv", environment=True)

    # Install the webui requirements
    update_requirements(initial_installation=True, pull=False)
2023-09-22 00:35:53 +02:00
2024-03-06 16:36:23 +01:00
def get_extensions_names():
    """Return names of entries in ./extensions that contain a requirements.txt file.

    The result follows os.listdir order.
    """
    names = []
    for entry in os.listdir('extensions'):
        if os.path.isfile(os.path.join('extensions', entry, 'requirements.txt')):
            names.append(entry)
    return names
2024-03-04 08:46:39 +01:00
def install_extensions_requirements():
    """pip-install (with --upgrade) each extension's requirements.txt in the conda env.

    Failures are not fatal (assert_success=False); progress is printed as
    '[i/total]' per extension.
    """
    print_big_message("Installing extensions requirements.\nSome of these may fail on Windows.\nDon't worry if you see error messages, as they will not affect the main program.")
    extensions = get_extensions_names()
    total = len(extensions)
    for position, extension in enumerate(extensions, start=1):
        print(f"\n\n--- [{position}/{total}]: {extension}\n\n")
        req_file = os.path.join("extensions", extension, "requirements.txt")
        run_cmd(f"python -m pip install -r {req_file} --upgrade", assert_success=False, environment=True)
2024-03-04 21:35:41 +01:00
def _git_requirement_package_name(req):
    # Derive the package name from a "git+<url>[@ref]" requirement line.
    url = req.replace("git+", "")
    name = url.split("/")[-1].split("@")[0]
    # Remove a literal ".git" suffix. str.rstrip(".git") strips any trailing
    # '.', 'g', 'i', 't' characters and mangled names such as "...-bigtit.git".
    if name.endswith(".git"):
        name = name[:-len(".git")]
    return name


def update_requirements(initial_installation=False, pull=True):
    """Bring the repository and its Python dependencies up to date.

    Parameters:
        initial_installation: when False, PyTorch itself is upgraded first
            via update_pytorch().
        pull: when True, run `git pull --autostash`. The script exits (status 1)
            if any start/update script or one_click.py changed, so the new
            version gets re-run.

    The requirements file is chosen from the installed torch build (ROCm,
    CPU/Intel, macOS, or CUDA), with the _noavx2 variant when AVX2 is not
    detected. It is patched for CUDA 11.8 when needed and installed with pip.
    """
    # Create .git directory if missing
    if not os.path.exists(os.path.join(script_dir, ".git")):
        git_creation_cmd = 'git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && git fetch && git symbolic-ref refs/remotes/origin/HEAD refs/remotes/origin/main && git reset --hard origin/main && git branch --set-upstream-to=origin/main'
        run_cmd(git_creation_cmd, environment=True, assert_success=True)

    if pull:
        print_big_message("Updating the local copy of the repository with \"git pull\"")

        files_to_check = [
            'start_linux.sh', 'start_macos.sh', 'start_windows.bat', 'start_wsl.bat',
            'update_wizard_linux.sh', 'update_wizard_macos.sh', 'update_wizard_windows.bat', 'update_wizard_wsl.bat',
            'one_click.py'
        ]

        before_pull_hashes = {file_name: calculate_file_hash(file_name) for file_name in files_to_check}
        run_cmd("git pull --autostash", assert_success=True, environment=True)
        after_pull_hashes = {file_name: calculate_file_hash(file_name) for file_name in files_to_check}

        # Check for differences in installation file hashes
        for file_name in files_to_check:
            if before_pull_hashes[file_name] != after_pull_hashes[file_name]:
                print_big_message(f"File '{file_name}' was updated during 'git pull'. Please run the script again.")
                # sys.exit, not the site-provided exit() helper (absent under `python -S`).
                sys.exit(1)

    if os.environ.get("INSTALL_EXTENSIONS", "").lower() in ("yes", "y", "true", "1", "t", "on"):
        install_extensions_requirements()

    # Update PyTorch
    if not initial_installation:
        update_pytorch()

    # Detect the PyTorch version
    torver = torch_version()
    is_cuda = '+cu' in torver
    is_cuda118 = '+cu118' in torver  # 2.1.0+cu118
    is_rocm = '+rocm' in torver  # 2.0.1+rocm5.4.2
    is_intel = '+cxx11' in torver  # 2.0.1a0+cxx11.abi
    is_cpu = '+cpu' in torver  # 2.0.1+cpu

    if is_rocm:
        base_requirements = "requirements_amd" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
    elif is_cpu or is_intel:
        base_requirements = "requirements_cpu_only" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
    elif is_macos():
        base_requirements = "requirements_apple_" + ("intel" if is_x86_64() else "silicon") + ".txt"
    else:
        base_requirements = "requirements" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"

    requirements_file = base_requirements

    print_big_message(f"Installing webui requirements from file: {requirements_file}")
    print(f"TORCH: {torver}\n")

    # Prepare the requirements file (closed promptly via the context manager)
    with open(requirements_file) as req_file:
        textgen_requirements = req_file.read().splitlines()

    if is_cuda118:
        # Retarget CUDA 12 wheels to CUDA 11.8 and drop autoawq entries.
        textgen_requirements = [
            req.replace('+cu121', '+cu118').replace('+cu122', '+cu118')
            for req in textgen_requirements
            if "autoawq" not in req.lower()
        ]

    if is_windows() and is_cuda118:  # No flash-attention on Windows for CUDA 11
        textgen_requirements = [req for req in textgen_requirements if 'oobabooga/flash-attention' not in req]

    with open('temp_requirements.txt', 'w') as file:
        file.write('\n'.join(textgen_requirements))

    # Workaround for git+ packages not updating properly.
    git_requirements = [req for req in textgen_requirements if req.startswith("git+")]
    for req in git_requirements:
        package_name = _git_requirement_package_name(req)
        run_cmd(f"python -m pip uninstall -y {package_name}", environment=True)
        print(f"Uninstalled {package_name}")

    # Install/update the project requirements. The finally clause also removes
    # the temp file when run_cmd exits via SystemExit on failure.
    try:
        run_cmd("python -m pip install -r temp_requirements.txt --upgrade", assert_success=True, environment=True)
    finally:
        os.remove('temp_requirements.txt')

    # Check for '+cu' or '+rocm' in version string to determine if torch uses CUDA or ROCm. Check for pytorch-cuda as well for backwards compatibility
    if not any((is_cuda, is_rocm)) and run_cmd("conda list -f pytorch-cuda | grep pytorch-cuda", environment=True, capture_output=True).returncode == 1:
        clear_cache()
        return

    if not os.path.exists("repositories/"):
        os.mkdir("repositories")

    clear_cache()
def launch_webui():
    """Run server.py in the conda env, passing the collected module-level flags."""
    command = f"python server.py {flags}"
    run_cmd(command, environment=True)
2023-09-22 00:35:53 +02:00
# Script entry point: either the interactive update wizard (--update-wizard)
# or the normal path (install if needed, then launch the web UI).
if __name__ == "__main__":
    # Verifies we are in a conda environment
    check_env()

    # Only --update-wizard is parsed here; every other argument is left for server.py
    # (it is part of the module-level `flags` string).
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--update-wizard', action='store_true', help='Launch a menu with update options.')
    args, _ = parser.parse_known_args()

    if args.update_wizard:
        # Menu loop; it only ends via choice 'N' (sys.exit) or a failing command.
        while True:
            choice = get_user_choice(
                "What would you like to do?",
                {
                    'A': 'Update the web UI',
                    'B': 'Install/update extensions requirements',
                    'C': 'Revert local changes to repository files with \"git reset --hard\"',
                    'N': 'Nothing (exit)'
                },
            )

            if choice == 'A':
                update_requirements()
            elif choice == 'B':
                # 'A' is reserved for "All extensions", so extension keys start at 'B'
                # (index i + 1).
                choices = {'A': 'All extensions'}
                for i, name in enumerate(get_extensions_names()):
                    key = generate_alphabetic_sequence(i + 1)
                    choices[key] = name

                choice = get_user_choice("What extension?", choices)

                if choice == 'A':
                    install_extensions_requirements()
                else:
                    extension_req_path = os.path.join("extensions", choices[choice], "requirements.txt")
                    run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True)

                # Re-apply the core requirements without pulling. NOTE(review): presumably so
                # extension installs cannot leave core pins changed — confirm intent.
                update_requirements(pull=False)
            elif choice == 'C':
                run_cmd("git reset --hard", assert_success=True, environment=True)
            elif choice == 'N':
                sys.exit()
    else:
        if not is_installed():
            install_webui()
            os.chdir(script_dir)

        if os.environ.get("LAUNCH_AFTER_INSTALL", "").lower() in ("no", "n", "false", "0", "f", "off"):
            print_big_message("Will now exit due to LAUNCH_AFTER_INSTALL.")
            sys.exit()

        # Check if a model has been downloaded yet
        if '--model-dir' in flags:
            # Splits on ' ' or '=' while maintaining spaces within quotes
            flags_list = re.split(' +(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)|=', flags)
            # The token after '--model-dir' is its value; surrounding quotes are stripped.
            # NOTE(review): raises IndexError if '--model-dir' is the last token.
            model_dir = [flags_list[(flags_list.index(flag) + 1)] for flag in flags_list if flag == '--model-dir'][0].strip('"\'')
        else:
            model_dir = 'models'

        # Any entry other than .txt/.yaml files in the model directory counts as a model.
        if len([item for item in glob.glob(f'{model_dir}/*') if not item.endswith(('.txt', '.yaml'))]) == 0:
            print_big_message("You haven't downloaded any model yet.\nOnce the web UI launches, head over to the \"Model\" tab and download one.")

        # Workaround for llama-cpp-python loading paths in CUDA env vars even if they do not exist
        conda_path_bin = os.path.join(conda_env_path, "bin")
        if not os.path.exists(conda_path_bin):
            os.mkdir(conda_path_bin)

        # Launch the webui
        launch_webui()