diff --git a/webui.py b/webui.py index d5071c80..aaa77885 100644 --- a/webui.py +++ b/webui.py @@ -8,6 +8,11 @@ import sys script_dir = os.getcwd() conda_env_path = os.path.join(script_dir, "installer_files", "env") +# Remove the '# ' from the following lines as needed for your AMD GPU on Linux +# os.environ["ROCM_PATH"] = '/opt/rocm' +# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0' +# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030' + # Command-line flags if "OOBABOOGA_FLAGS" in os.environ: CMD_FLAGS = os.environ["OOBABOOGA_FLAGS"] @@ -23,42 +28,28 @@ else: CMD_FLAGS = '' -# Remove the '# ' from the following lines as needed for your AMD GPU on Linux -# os.environ["ROCM_PATH"] = '/opt/rocm' -# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0' -# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030' +def is_linux(): + return sys.platform.startswith("linux") -def print_big_message(message): - message = message.strip() - lines = message.split('\n') - print("\n\n*******************************************************************") - for line in lines: - if line.strip() != '': - print("*", line) - - print("*******************************************************************\n\n") +def is_windows(): + return sys.platform.startswith("win") -def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None): - # Use the conda environment - if environment: - if sys.platform.startswith("win"): - conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat") - cmd = "\"" + conda_bat_path + "\" activate \"" + conda_env_path + "\" >nul && " + cmd - else: - conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh") - cmd = ". \"" + conda_sh_path + "\" && conda activate \"" + conda_env_path + "\" && " + cmd +def is_macos(): + return sys.platform.startswith("darwin") - # Run shell commands - result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env) - # Assert the command ran successfully - if assert_success and result.returncode != 0: - print("Command '" + cmd + "' failed with exit status code '" + str(result.returncode) + "'. Exiting...") - sys.exit() +def is_installed(): + for sitedir in site.getsitepackages(): + if "site-packages" in sitedir and conda_env_path in sitedir: + site_packages_path = sitedir + break - return result + if site_packages_path: + return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py')) + else: + return os.path.isdir(conda_env_path) def check_env(): @@ -78,19 +69,41 @@ def clear_cache(): run_cmd("conda clean -a -y", environment=True) run_cmd("python -m pip cache purge", environment=True) -def is_installed(): - for sitedir in site.getsitepackages(): - if "site-packages" in sitedir and conda_env_path in sitedir: - site_packages_path = sitedir - break - if site_packages_path: - return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py')) - else: - return os.path.isdir(conda_env_path) +def print_big_message(message): + message = message.strip() + lines = message.split('\n') + print("\n\n*******************************************************************") + for line in lines: + if line.strip() != '': + print("*", line) + + print("*******************************************************************\n\n") + + +def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None): + # Use the conda environment + if environment: + if is_windows(): + conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat") + cmd = "\"" + conda_bat_path + "\" activate \"" + conda_env_path + "\" >nul && " + cmd + else: + conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh") + cmd = ". \"" + conda_sh_path + "\" && conda activate \"" + conda_env_path + "\" && " + cmd + + # Run shell commands + result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env) + + # Assert the command ran successfully + if assert_success and result.returncode != 0: + print("Command '" + cmd + "' failed with exit status code '" + str(result.returncode) + "'. Exiting...") + sys.exit() + + return result + def install_dependencies(): - # Select your GPU or, choose to run in CPU mode + print("What is your GPU") print() print("A) NVIDIA") @@ -98,27 +111,28 @@ def install_dependencies(): print("C) Apple M Series") print("D) None (I want to run models in CPU mode)") print() + + # Select your GPU, or choose to run in CPU mode gpuchoice = input("Input> ").lower() while gpuchoice not in ['a', 'b', 'c', 'd']: print("Invalid choice. Please try again.") gpuchoice = input("Input> ").lower() - if gpuchoice == "d": print_big_message("Once the installation ends, make sure to open CMD_FLAGS.txt with\na text editor and add the --cpu flag.") - # Install the version of PyTorch needed + # Install Pytorch if gpuchoice == "a": run_cmd('conda install -y -k cuda ninja git -c nvidia/label/cuda-11.7.0 -c nvidia && python -m pip install torch==2.0.1+cu117 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117', assert_success=True, environment=True) - elif gpuchoice == "b" and not sys.platform.startswith("darwin"): - if sys.platform.startswith("linux"): + elif gpuchoice == "b" and not is_macos(): + if is_linux(): run_cmd('conda install -y -k ninja git && python -m pip install torch==2.0.1+rocm5.4.2 torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2', assert_success=True, environment=True) else: print("AMD GPUs are only supported on Linux. Exiting...") sys.exit() - elif (gpuchoice == "c" or gpuchoice == "b") and sys.platform.startswith("darwin"): + elif (gpuchoice == "c" or gpuchoice == "b") and is_macos(): run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True) elif gpuchoice == "d" or gpuchoice == "c": - if sys.platform.startswith("linux"): + if is_linux(): run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu", assert_success=True, environment=True) else: run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True) @@ -132,8 +146,8 @@ def update_dependencies(initial_installation=False): if not os.path.isdir(os.path.join(script_dir, ".git")): git_creation_cmd = 'git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && git fetch && git remote set-head origin -a && git reset origin/HEAD && git branch --set-upstream-to=origin/HEAD' run_cmd(git_creation_cmd, environment=True, assert_success=True) - - run_cmd("git pull --autostash", assert_success=True, environment=True) # TODO is there a better way? + + run_cmd("git pull --autostash", assert_success=True, environment=True) # Install the extensions dependencies (only on the first install) if initial_installation: @@ -196,11 +210,11 @@ def update_dependencies(initial_installation=False): run_cmd("python -m pip install " + exllama_rocm, environment=True) # Fix JIT compile issue with exllama in Linux/WSL - if sys.platform.startswith("linux") and not os.path.exists(f"{conda_env_path}/lib64"): + if is_linux() and not os.path.exists(f"{conda_env_path}/lib64"): run_cmd(f'ln -s "{conda_env_path}/lib" "{conda_env_path}/lib64"', environment=True) # On some Linux distributions, g++ may not exist or be the wrong version to compile GPTQ-for-LLaMa - if sys.platform.startswith("linux"): + if is_linux(): gxx_output = run_cmd("g++ -dumpfullversion -dumpversion", environment=True, capture_output=True) if gxx_output.returncode != 0 or int(gxx_output.stdout.strip().split(b".")[0]) > 11: # Install the correct version of g++