mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 09:19:23 +01:00
Minor changes / reorder some functions
This commit is contained in:
parent
84b5a519cb
commit
66363a4d70
112
webui.py
112
webui.py
@ -8,6 +8,11 @@ import sys
|
|||||||
script_dir = os.getcwd()
|
script_dir = os.getcwd()
|
||||||
conda_env_path = os.path.join(script_dir, "installer_files", "env")
|
conda_env_path = os.path.join(script_dir, "installer_files", "env")
|
||||||
|
|
||||||
|
# Remove the '# ' from the following lines as needed for your AMD GPU on Linux
|
||||||
|
# os.environ["ROCM_PATH"] = '/opt/rocm'
|
||||||
|
# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0'
|
||||||
|
# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030'
|
||||||
|
|
||||||
# Command-line flags
|
# Command-line flags
|
||||||
if "OOBABOOGA_FLAGS" in os.environ:
|
if "OOBABOOGA_FLAGS" in os.environ:
|
||||||
CMD_FLAGS = os.environ["OOBABOOGA_FLAGS"]
|
CMD_FLAGS = os.environ["OOBABOOGA_FLAGS"]
|
||||||
@ -23,42 +28,28 @@ else:
|
|||||||
CMD_FLAGS = ''
|
CMD_FLAGS = ''
|
||||||
|
|
||||||
|
|
||||||
# Remove the '# ' from the following lines as needed for your AMD GPU on Linux
|
def is_linux():
|
||||||
# os.environ["ROCM_PATH"] = '/opt/rocm'
|
return sys.platform.startswith("linux")
|
||||||
# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0'
|
|
||||||
# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030'
|
|
||||||
|
|
||||||
|
|
||||||
def print_big_message(message):
|
def is_windows():
|
||||||
message = message.strip()
|
return sys.platform.startswith("win")
|
||||||
lines = message.split('\n')
|
|
||||||
print("\n\n*******************************************************************")
|
|
||||||
for line in lines:
|
|
||||||
if line.strip() != '':
|
|
||||||
print("*", line)
|
|
||||||
|
|
||||||
print("*******************************************************************\n\n")
|
|
||||||
|
|
||||||
|
|
||||||
def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None):
|
def is_macos():
|
||||||
# Use the conda environment
|
return sys.platform.startswith("darwin")
|
||||||
if environment:
|
|
||||||
if sys.platform.startswith("win"):
|
|
||||||
conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat")
|
|
||||||
cmd = "\"" + conda_bat_path + "\" activate \"" + conda_env_path + "\" >nul && " + cmd
|
|
||||||
else:
|
|
||||||
conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh")
|
|
||||||
cmd = ". \"" + conda_sh_path + "\" && conda activate \"" + conda_env_path + "\" && " + cmd
|
|
||||||
|
|
||||||
# Run shell commands
|
|
||||||
result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env)
|
|
||||||
|
|
||||||
# Assert the command ran successfully
|
def is_installed():
|
||||||
if assert_success and result.returncode != 0:
|
for sitedir in site.getsitepackages():
|
||||||
print("Command '" + cmd + "' failed with exit status code '" + str(result.returncode) + "'. Exiting...")
|
if "site-packages" in sitedir and conda_env_path in sitedir:
|
||||||
sys.exit()
|
site_packages_path = sitedir
|
||||||
|
break
|
||||||
|
|
||||||
return result
|
if site_packages_path:
|
||||||
|
return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py'))
|
||||||
|
else:
|
||||||
|
return os.path.isdir(conda_env_path)
|
||||||
|
|
||||||
|
|
||||||
def check_env():
|
def check_env():
|
||||||
@ -78,19 +69,41 @@ def clear_cache():
|
|||||||
run_cmd("conda clean -a -y", environment=True)
|
run_cmd("conda clean -a -y", environment=True)
|
||||||
run_cmd("python -m pip cache purge", environment=True)
|
run_cmd("python -m pip cache purge", environment=True)
|
||||||
|
|
||||||
def is_installed():
|
|
||||||
for sitedir in site.getsitepackages():
|
|
||||||
if "site-packages" in sitedir and conda_env_path in sitedir:
|
|
||||||
site_packages_path = sitedir
|
|
||||||
break
|
|
||||||
|
|
||||||
if site_packages_path:
|
def print_big_message(message):
|
||||||
return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py'))
|
message = message.strip()
|
||||||
else:
|
lines = message.split('\n')
|
||||||
return os.path.isdir(conda_env_path)
|
print("\n\n*******************************************************************")
|
||||||
|
for line in lines:
|
||||||
|
if line.strip() != '':
|
||||||
|
print("*", line)
|
||||||
|
|
||||||
|
print("*******************************************************************\n\n")
|
||||||
|
|
||||||
|
|
||||||
|
def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None):
|
||||||
|
# Use the conda environment
|
||||||
|
if environment:
|
||||||
|
if is_windows():
|
||||||
|
conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat")
|
||||||
|
cmd = "\"" + conda_bat_path + "\" activate \"" + conda_env_path + "\" >nul && " + cmd
|
||||||
|
else:
|
||||||
|
conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh")
|
||||||
|
cmd = ". \"" + conda_sh_path + "\" && conda activate \"" + conda_env_path + "\" && " + cmd
|
||||||
|
|
||||||
|
# Run shell commands
|
||||||
|
result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env)
|
||||||
|
|
||||||
|
# Assert the command ran successfully
|
||||||
|
if assert_success and result.returncode != 0:
|
||||||
|
print("Command '" + cmd + "' failed with exit status code '" + str(result.returncode) + "'. Exiting...")
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def install_dependencies():
|
def install_dependencies():
|
||||||
# Select your GPU or, choose to run in CPU mode
|
|
||||||
print("What is your GPU")
|
print("What is your GPU")
|
||||||
print()
|
print()
|
||||||
print("A) NVIDIA")
|
print("A) NVIDIA")
|
||||||
@ -98,27 +111,28 @@ def install_dependencies():
|
|||||||
print("C) Apple M Series")
|
print("C) Apple M Series")
|
||||||
print("D) None (I want to run models in CPU mode)")
|
print("D) None (I want to run models in CPU mode)")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
# Select your GPU, or choose to run in CPU mode
|
||||||
gpuchoice = input("Input> ").lower()
|
gpuchoice = input("Input> ").lower()
|
||||||
while gpuchoice not in ['a', 'b', 'c', 'd']:
|
while gpuchoice not in ['a', 'b', 'c', 'd']:
|
||||||
print("Invalid choice. Please try again.")
|
print("Invalid choice. Please try again.")
|
||||||
gpuchoice = input("Input> ").lower()
|
gpuchoice = input("Input> ").lower()
|
||||||
|
|
||||||
if gpuchoice == "d":
|
if gpuchoice == "d":
|
||||||
print_big_message("Once the installation ends, make sure to open CMD_FLAGS.txt with\na text editor and add the --cpu flag.")
|
print_big_message("Once the installation ends, make sure to open CMD_FLAGS.txt with\na text editor and add the --cpu flag.")
|
||||||
|
|
||||||
# Install the version of PyTorch needed
|
# Install Pytorch
|
||||||
if gpuchoice == "a":
|
if gpuchoice == "a":
|
||||||
run_cmd('conda install -y -k cuda ninja git -c nvidia/label/cuda-11.7.0 -c nvidia && python -m pip install torch==2.0.1+cu117 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117', assert_success=True, environment=True)
|
run_cmd('conda install -y -k cuda ninja git -c nvidia/label/cuda-11.7.0 -c nvidia && python -m pip install torch==2.0.1+cu117 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117', assert_success=True, environment=True)
|
||||||
elif gpuchoice == "b" and not sys.platform.startswith("darwin"):
|
elif gpuchoice == "b" and not is_macos():
|
||||||
if sys.platform.startswith("linux"):
|
if is_linux():
|
||||||
run_cmd('conda install -y -k ninja git && python -m pip install torch==2.0.1+rocm5.4.2 torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2', assert_success=True, environment=True)
|
run_cmd('conda install -y -k ninja git && python -m pip install torch==2.0.1+rocm5.4.2 torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2', assert_success=True, environment=True)
|
||||||
else:
|
else:
|
||||||
print("AMD GPUs are only supported on Linux. Exiting...")
|
print("AMD GPUs are only supported on Linux. Exiting...")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
elif (gpuchoice == "c" or gpuchoice == "b") and sys.platform.startswith("darwin"):
|
elif (gpuchoice == "c" or gpuchoice == "b") and is_macos():
|
||||||
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True)
|
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True)
|
||||||
elif gpuchoice == "d" or gpuchoice == "c":
|
elif gpuchoice == "d" or gpuchoice == "c":
|
||||||
if sys.platform.startswith("linux"):
|
if is_linux():
|
||||||
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu", assert_success=True, environment=True)
|
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu", assert_success=True, environment=True)
|
||||||
else:
|
else:
|
||||||
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True)
|
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True)
|
||||||
@ -132,8 +146,8 @@ def update_dependencies(initial_installation=False):
|
|||||||
if not os.path.isdir(os.path.join(script_dir, ".git")):
|
if not os.path.isdir(os.path.join(script_dir, ".git")):
|
||||||
git_creation_cmd = 'git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && git fetch && git remote set-head origin -a && git reset origin/HEAD && git branch --set-upstream-to=origin/HEAD'
|
git_creation_cmd = 'git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && git fetch && git remote set-head origin -a && git reset origin/HEAD && git branch --set-upstream-to=origin/HEAD'
|
||||||
run_cmd(git_creation_cmd, environment=True, assert_success=True)
|
run_cmd(git_creation_cmd, environment=True, assert_success=True)
|
||||||
|
|
||||||
run_cmd("git pull --autostash", assert_success=True, environment=True) # TODO is there a better way?
|
run_cmd("git pull --autostash", assert_success=True, environment=True)
|
||||||
|
|
||||||
# Install the extensions dependencies (only on the first install)
|
# Install the extensions dependencies (only on the first install)
|
||||||
if initial_installation:
|
if initial_installation:
|
||||||
@ -196,11 +210,11 @@ def update_dependencies(initial_installation=False):
|
|||||||
run_cmd("python -m pip install " + exllama_rocm, environment=True)
|
run_cmd("python -m pip install " + exllama_rocm, environment=True)
|
||||||
|
|
||||||
# Fix JIT compile issue with exllama in Linux/WSL
|
# Fix JIT compile issue with exllama in Linux/WSL
|
||||||
if sys.platform.startswith("linux") and not os.path.exists(f"{conda_env_path}/lib64"):
|
if is_linux() and not os.path.exists(f"{conda_env_path}/lib64"):
|
||||||
run_cmd(f'ln -s "{conda_env_path}/lib" "{conda_env_path}/lib64"', environment=True)
|
run_cmd(f'ln -s "{conda_env_path}/lib" "{conda_env_path}/lib64"', environment=True)
|
||||||
|
|
||||||
# On some Linux distributions, g++ may not exist or be the wrong version to compile GPTQ-for-LLaMa
|
# On some Linux distributions, g++ may not exist or be the wrong version to compile GPTQ-for-LLaMa
|
||||||
if sys.platform.startswith("linux"):
|
if is_linux():
|
||||||
gxx_output = run_cmd("g++ -dumpfullversion -dumpversion", environment=True, capture_output=True)
|
gxx_output = run_cmd("g++ -dumpfullversion -dumpversion", environment=True, capture_output=True)
|
||||||
if gxx_output.returncode != 0 or int(gxx_output.stdout.strip().split(b".")[0]) > 11:
|
if gxx_output.returncode != 0 or int(gxx_output.stdout.strip().split(b".")[0]) > 11:
|
||||||
# Install the correct version of g++
|
# Install the correct version of g++
|
||||||
|
Loading…
Reference in New Issue
Block a user