Installer: simplify the script

This commit is contained in:
oobabooga 2025-01-21 09:58:13 -08:00
parent 2bf8788c30
commit ff250dd800

View File

@ -102,31 +102,24 @@ def torch_version():
def update_pytorch():
print_big_message("Checking for PyTorch updates.")
torver = torch_version()
is_cuda = '+cu' in torver
is_cuda118 = '+cu118' in torver # 2.1.0+cu118
is_rocm = '+rocm' in torver # 2.0.1+rocm5.4.2
is_intel = '+cxx11' in torver # 2.0.1a0+cxx11.abi
is_cpu = '+cpu' in torver # 2.0.1+cpu
base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}"
install_pytorch = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION} "
if "+cu118" in torver:
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu118"
elif "+cu" in torver:
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu121"
elif "+rocm" in torver:
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.1"
elif "+cpu" in torver:
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu"
elif "+cxx11" in torver:
intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
install_cmd = f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
else:
install_cmd = base_cmd
if is_cuda118:
install_pytorch += "--index-url https://download.pytorch.org/whl/cu118"
elif is_cuda:
install_pytorch += "--index-url https://download.pytorch.org/whl/cu121"
elif is_rocm:
install_pytorch += "--index-url https://download.pytorch.org/whl/rocm6.1"
elif is_cpu:
install_pytorch += "--index-url https://download.pytorch.org/whl/cpu"
elif is_intel:
if is_linux():
install_pytorch = "python -m pip install --upgrade torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
else:
install_pytorch = "python -m pip install --upgrade torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
run_cmd(f"{install_pytorch}", assert_success=True, environment=True)
run_cmd(install_cmd, assert_success=True, environment=True)
def is_installed():
@ -340,69 +333,63 @@ def install_extensions_requirements():
def update_requirements(initial_installation=False, pull=True):
# Create .git directory if missing
if not os.path.exists(os.path.join(script_dir, ".git")):
git_creation_cmd = 'git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && git fetch && git symbolic-ref refs/remotes/origin/HEAD refs/remotes/origin/main && git reset --hard origin/main && git branch --set-upstream-to=origin/main'
run_cmd(git_creation_cmd, environment=True, assert_success=True)
run_cmd(
"git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && "
"git fetch && git symbolic-ref refs/remotes/origin/HEAD refs/remotes/origin/main && "
"git reset --hard origin/main && git branch --set-upstream-to=origin/main",
environment=True,
assert_success=True
)
# Detect the requirements file from the PyTorch version
torver = torch_version()
is_cuda = '+cu' in torver
is_cuda118 = '+cu118' in torver # 2.1.0+cu118
is_rocm = '+rocm' in torver # 2.0.1+rocm5.4.2
is_intel = '+cxx11' in torver # 2.0.1a0+cxx11.abi
is_cpu = '+cpu' in torver # 2.0.1+cpu
if is_rocm:
base_requirements = "requirements_amd" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
elif is_cpu or is_intel:
base_requirements = "requirements_cpu_only" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
if "+rocm" in torver:
requirements_file = "requirements_amd" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
elif "+cpu" in torver or "+cxx11" in torver:
requirements_file = "requirements_cpu_only" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
elif is_macos():
base_requirements = "requirements_apple_" + ("intel" if is_x86_64() else "silicon") + ".txt"
requirements_file = "requirements_apple_" + ("intel" if is_x86_64() else "silicon") + ".txt"
else:
base_requirements = "requirements" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
requirements_file = "requirements" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
requirements_file = base_requirements
# Call git pull, while checking if .whl requirements have changed
wheels_changed_from_flag = False
if os.path.exists('.wheels_changed_flag'):
# Check and clear the wheels changed flag
wheels_changed = os.path.exists('.wheels_changed_flag')
if wheels_changed:
os.remove('.wheels_changed_flag')
wheels_changed_from_flag = True
if pull:
# Read .whl lines before pulling
before_pull_whl_lines = []
if os.path.exists(requirements_file):
with open(requirements_file, 'r') as f:
before_pull_whl_lines = [line for line in f if '.whl' in line]
print_big_message("Updating the local copy of the repository with \"git pull\"")
print_big_message('Updating the local copy of the repository with "git pull"')
# Hash files before pulling
files_to_check = [
'start_linux.sh', 'start_macos.sh', 'start_windows.bat', 'start_wsl.bat',
'update_wizard_linux.sh', 'update_wizard_macos.sh', 'update_wizard_windows.bat', 'update_wizard_wsl.bat',
'one_click.py'
]
before_hashes = {file: calculate_file_hash(file) for file in files_to_check}
before_pull_hashes = {file_name: calculate_file_hash(file_name) for file_name in files_to_check}
# Perform the git pull
run_cmd("git pull --autostash", assert_success=True, environment=True)
after_pull_hashes = {file_name: calculate_file_hash(file_name) for file_name in files_to_check}
# Check hashes after pulling
after_hashes = {file: calculate_file_hash(file) for file in files_to_check}
if os.path.exists(requirements_file):
with open(requirements_file, 'r') as f:
after_pull_whl_lines = [line for line in f if '.whl' in line]
# Check for differences in installation file hashes
for file_name in files_to_check:
if before_pull_hashes[file_name] != after_pull_hashes[file_name]:
print_big_message(f"File '{file_name}' was updated during 'git pull'. Please run the script again.")
# Check if wheels changed during this pull
wheels_changed = before_pull_whl_lines != after_pull_whl_lines
if wheels_changed:
# Check for changes
for file in files_to_check:
if before_hashes[file] != after_hashes[file]:
print_big_message(f"File '{file}' was updated during 'git pull'. Please run the script again.")
if before_pull_whl_lines != after_pull_whl_lines:
open('.wheels_changed_flag', 'w').close()
exit(1)
wheels_changed = wheels_changed_from_flag
if pull:
wheels_changed = wheels_changed or (before_pull_whl_lines != after_pull_whl_lines)
if os.environ.get("INSTALL_EXTENSIONS", "").lower() in ("yes", "y", "true", "1", "t", "on"):
@ -419,16 +406,16 @@ def update_requirements(initial_installation=False, pull=True):
textgen_requirements = open(requirements_file).read().splitlines()
if not initial_installation and not wheels_changed:
textgen_requirements = [line for line in textgen_requirements if not '.whl' in line]
textgen_requirements = [line for line in textgen_requirements if '.whl' not in line]
if is_cuda118:
if "+cu118" in torver:
textgen_requirements = [
req.replace('+cu121', '+cu118').replace('+cu122', '+cu118')
for req in textgen_requirements
if "autoawq" not in req.lower()
]
if is_windows() and is_cuda118: # No flash-attention on Windows for CUDA 11
if is_windows() and "+cu118" in torver: # No flash-attention on Windows for CUDA 11
textgen_requirements = [req for req in textgen_requirements if 'oobabooga/flash-attention' not in req]
with open('temp_requirements.txt', 'w') as file: