Installer: do not redownload wheels for each update

This commit is contained in:
oobabooga 2025-01-21 08:35:35 -08:00
parent f8a5b0bc43
commit ecb5d3c485

View File

@ -101,7 +101,7 @@ def torch_version():
def update_pytorch():
print_big_message("Checking for PyTorch updates")
print_big_message("Checking for PyTorch updates.")
torver = torch_version()
is_cuda = '+cu' in torver
@ -343,6 +343,31 @@ def update_requirements(initial_installation=False, pull=True):
git_creation_cmd = 'git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && git fetch && git symbolic-ref refs/remotes/origin/HEAD refs/remotes/origin/main && git reset --hard origin/main && git branch --set-upstream-to=origin/main'
run_cmd(git_creation_cmd, environment=True, assert_success=True)
# Detect the requirements file from the PyTorch version
torver = torch_version()
is_cuda = '+cu' in torver
is_cuda118 = '+cu118' in torver # 2.1.0+cu118
is_rocm = '+rocm' in torver # 2.0.1+rocm5.4.2
is_intel = '+cxx11' in torver # 2.0.1a0+cxx11.abi
is_cpu = '+cpu' in torver # 2.0.1+cpu
if is_rocm:
base_requirements = "requirements_amd" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
elif is_cpu or is_intel:
base_requirements = "requirements_cpu_only" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
elif is_macos():
base_requirements = "requirements_apple_" + ("intel" if is_x86_64() else "silicon") + ".txt"
else:
base_requirements = "requirements" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
requirements_file = base_requirements
# Call git pull
before_pull_whl_lines = []
if os.path.exists(requirements_file):
with open(requirements_file, 'r') as f:
before_pull_whl_lines = [line for line in f if '.whl' in line]
if pull:
print_big_message("Updating the local copy of the repository with \"git pull\"")
@ -362,6 +387,11 @@ def update_requirements(initial_installation=False, pull=True):
print_big_message(f"File '{file_name}' was updated during 'git pull'. Please run the script again.")
exit(1)
after_pull_whl_lines = []
if os.path.exists(requirements_file):
with open(requirements_file, 'r') as f:
after_pull_whl_lines = [line for line in f if '.whl' in line]
if os.environ.get("INSTALL_EXTENSIONS", "").lower() in ("yes", "y", "true", "1", "t", "on"):
install_extensions_requirements()
@ -369,30 +399,16 @@ def update_requirements(initial_installation=False, pull=True):
if not initial_installation:
update_pytorch()
# Detect the PyTorch version
torver = torch_version()
is_cuda = '+cu' in torver
is_cuda118 = '+cu118' in torver # 2.1.0+cu118
is_rocm = '+rocm' in torver # 2.0.1+rocm5.4.2
is_intel = '+cxx11' in torver # 2.0.1a0+cxx11.abi
is_cpu = '+cpu' in torver # 2.0.1+cpu
if is_rocm:
base_requirements = "requirements_amd" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
elif is_cpu or is_intel:
base_requirements = "requirements_cpu_only" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
elif is_macos():
base_requirements = "requirements_apple_" + ("intel" if is_x86_64() else "silicon") + ".txt"
else:
base_requirements = "requirements" + ("_noavx2" if not cpu_has_avx2() else "") + ".txt"
requirements_file = base_requirements
print_big_message(f"Installing webui requirements from file: {requirements_file}")
print(f"TORCH: {torver}\n")
# Prepare the requirements file
textgen_requirements = open(requirements_file).read().splitlines()
whl_changed = before_pull_whl_lines != after_pull_whl_lines
if not initial_installation and not whl_changed:
textgen_requirements = [line for line in textgen_requirements if not '.whl' in line]
if is_cuda118:
textgen_requirements = [
req.replace('+cu121', '+cu118').replace('+cu122', '+cu118')
@ -416,16 +432,9 @@ def update_requirements(initial_installation=False, pull=True):
# Install/update the project requirements
run_cmd("python -m pip install -r temp_requirements.txt --upgrade", assert_success=True, environment=True)
# Clean up
os.remove('temp_requirements.txt')
# Check for '+cu' or '+rocm' in version string to determine if torch uses CUDA or ROCm. Check for pytorch-cuda as well for backwards compatibility
if not any((is_cuda, is_rocm)) and run_cmd("conda list -f pytorch-cuda | grep pytorch-cuda", environment=True, capture_output=True).returncode == 1:
clear_cache()
return
if not os.path.exists("repositories/"):
os.mkdir("repositories")
clear_cache()