Minor changes / reorder some functions

This commit is contained in:
oobabooga 2023-09-22 08:02:21 -07:00
parent 84b5a519cb
commit 66363a4d70

112
webui.py
View File

@ -8,6 +8,11 @@ import sys
script_dir = os.getcwd() script_dir = os.getcwd()
conda_env_path = os.path.join(script_dir, "installer_files", "env") conda_env_path = os.path.join(script_dir, "installer_files", "env")
# Remove the '# ' from the following lines as needed for your AMD GPU on Linux
# os.environ["ROCM_PATH"] = '/opt/rocm'
# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0'
# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030'
# Command-line flags # Command-line flags
if "OOBABOOGA_FLAGS" in os.environ: if "OOBABOOGA_FLAGS" in os.environ:
CMD_FLAGS = os.environ["OOBABOOGA_FLAGS"] CMD_FLAGS = os.environ["OOBABOOGA_FLAGS"]
@ -23,42 +28,28 @@ else:
CMD_FLAGS = '' CMD_FLAGS = ''
# Remove the '# ' from the following lines as needed for your AMD GPU on Linux def is_linux():
# os.environ["ROCM_PATH"] = '/opt/rocm' return sys.platform.startswith("linux")
# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0'
# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030'
def print_big_message(message): def is_windows():
message = message.strip() return sys.platform.startswith("win")
lines = message.split('\n')
print("\n\n*******************************************************************")
for line in lines:
if line.strip() != '':
print("*", line)
print("*******************************************************************\n\n")
def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None): def is_macos():
# Use the conda environment return sys.platform.startswith("darwin")
if environment:
if sys.platform.startswith("win"):
conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat") def is_installed():
cmd = "\"" + conda_bat_path + "\" activate \"" + conda_env_path + "\" >nul && " + cmd for sitedir in site.getsitepackages():
if "site-packages" in sitedir and conda_env_path in sitedir:
site_packages_path = sitedir
break
if site_packages_path:
return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py'))
else: else:
conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh") return os.path.isdir(conda_env_path)
cmd = ". \"" + conda_sh_path + "\" && conda activate \"" + conda_env_path + "\" && " + cmd
# Run shell commands
result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env)
# Assert the command ran successfully
if assert_success and result.returncode != 0:
print("Command '" + cmd + "' failed with exit status code '" + str(result.returncode) + "'. Exiting...")
sys.exit()
return result
def check_env(): def check_env():
@ -78,19 +69,41 @@ def clear_cache():
run_cmd("conda clean -a -y", environment=True) run_cmd("conda clean -a -y", environment=True)
run_cmd("python -m pip cache purge", environment=True) run_cmd("python -m pip cache purge", environment=True)
def is_installed():
for sitedir in site.getsitepackages():
if "site-packages" in sitedir and conda_env_path in sitedir:
site_packages_path = sitedir
break
if site_packages_path: def print_big_message(message):
return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py')) message = message.strip()
lines = message.split('\n')
print("\n\n*******************************************************************")
for line in lines:
if line.strip() != '':
print("*", line)
print("*******************************************************************\n\n")
def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None):
# Use the conda environment
if environment:
if is_windows():
conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat")
cmd = "\"" + conda_bat_path + "\" activate \"" + conda_env_path + "\" >nul && " + cmd
else: else:
return os.path.isdir(conda_env_path) conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh")
cmd = ". \"" + conda_sh_path + "\" && conda activate \"" + conda_env_path + "\" && " + cmd
# Run shell commands
result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env)
# Assert the command ran successfully
if assert_success and result.returncode != 0:
print("Command '" + cmd + "' failed with exit status code '" + str(result.returncode) + "'. Exiting...")
sys.exit()
return result
def install_dependencies(): def install_dependencies():
# Select your GPU or, choose to run in CPU mode
print("What is your GPU") print("What is your GPU")
print() print()
print("A) NVIDIA") print("A) NVIDIA")
@ -98,27 +111,28 @@ def install_dependencies():
print("C) Apple M Series") print("C) Apple M Series")
print("D) None (I want to run models in CPU mode)") print("D) None (I want to run models in CPU mode)")
print() print()
# Select your GPU, or choose to run in CPU mode
gpuchoice = input("Input> ").lower() gpuchoice = input("Input> ").lower()
while gpuchoice not in ['a', 'b', 'c', 'd']: while gpuchoice not in ['a', 'b', 'c', 'd']:
print("Invalid choice. Please try again.") print("Invalid choice. Please try again.")
gpuchoice = input("Input> ").lower() gpuchoice = input("Input> ").lower()
if gpuchoice == "d": if gpuchoice == "d":
print_big_message("Once the installation ends, make sure to open CMD_FLAGS.txt with\na text editor and add the --cpu flag.") print_big_message("Once the installation ends, make sure to open CMD_FLAGS.txt with\na text editor and add the --cpu flag.")
# Install the version of PyTorch needed # Install Pytorch
if gpuchoice == "a": if gpuchoice == "a":
run_cmd('conda install -y -k cuda ninja git -c nvidia/label/cuda-11.7.0 -c nvidia && python -m pip install torch==2.0.1+cu117 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117', assert_success=True, environment=True) run_cmd('conda install -y -k cuda ninja git -c nvidia/label/cuda-11.7.0 -c nvidia && python -m pip install torch==2.0.1+cu117 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117', assert_success=True, environment=True)
elif gpuchoice == "b" and not sys.platform.startswith("darwin"): elif gpuchoice == "b" and not is_macos():
if sys.platform.startswith("linux"): if is_linux():
run_cmd('conda install -y -k ninja git && python -m pip install torch==2.0.1+rocm5.4.2 torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2', assert_success=True, environment=True) run_cmd('conda install -y -k ninja git && python -m pip install torch==2.0.1+rocm5.4.2 torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2', assert_success=True, environment=True)
else: else:
print("AMD GPUs are only supported on Linux. Exiting...") print("AMD GPUs are only supported on Linux. Exiting...")
sys.exit() sys.exit()
elif (gpuchoice == "c" or gpuchoice == "b") and sys.platform.startswith("darwin"): elif (gpuchoice == "c" or gpuchoice == "b") and is_macos():
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True) run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True)
elif gpuchoice == "d" or gpuchoice == "c": elif gpuchoice == "d" or gpuchoice == "c":
if sys.platform.startswith("linux"): if is_linux():
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu", assert_success=True, environment=True) run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu", assert_success=True, environment=True)
else: else:
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True) run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True)
@ -133,7 +147,7 @@ def update_dependencies(initial_installation=False):
git_creation_cmd = 'git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && git fetch && git remote set-head origin -a && git reset origin/HEAD && git branch --set-upstream-to=origin/HEAD' git_creation_cmd = 'git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && git fetch && git remote set-head origin -a && git reset origin/HEAD && git branch --set-upstream-to=origin/HEAD'
run_cmd(git_creation_cmd, environment=True, assert_success=True) run_cmd(git_creation_cmd, environment=True, assert_success=True)
run_cmd("git pull --autostash", assert_success=True, environment=True) # TODO is there a better way? run_cmd("git pull --autostash", assert_success=True, environment=True)
# Install the extensions dependencies (only on the first install) # Install the extensions dependencies (only on the first install)
if initial_installation: if initial_installation:
@ -196,11 +210,11 @@ def update_dependencies(initial_installation=False):
run_cmd("python -m pip install " + exllama_rocm, environment=True) run_cmd("python -m pip install " + exllama_rocm, environment=True)
# Fix JIT compile issue with exllama in Linux/WSL # Fix JIT compile issue with exllama in Linux/WSL
if sys.platform.startswith("linux") and not os.path.exists(f"{conda_env_path}/lib64"): if is_linux() and not os.path.exists(f"{conda_env_path}/lib64"):
run_cmd(f'ln -s "{conda_env_path}/lib" "{conda_env_path}/lib64"', environment=True) run_cmd(f'ln -s "{conda_env_path}/lib" "{conda_env_path}/lib64"', environment=True)
# On some Linux distributions, g++ may not exist or be the wrong version to compile GPTQ-for-LLaMa # On some Linux distributions, g++ may not exist or be the wrong version to compile GPTQ-for-LLaMa
if sys.platform.startswith("linux"): if is_linux():
gxx_output = run_cmd("g++ -dumpfullversion -dumpversion", environment=True, capture_output=True) gxx_output = run_cmd("g++ -dumpfullversion -dumpversion", environment=True, capture_output=True)
if gxx_output.returncode != 0 or int(gxx_output.stdout.strip().split(b".")[0]) > 11: if gxx_output.returncode != 0 or int(gxx_output.stdout.strip().split(b".")[0]) > 11:
# Install the correct version of g++ # Install the correct version of g++