Fix wrong pytorch version on Linux+CPU

It was installing nvidia wheels
This commit is contained in:
oobabooga 2023-07-07 20:40:31 -03:00 committed by GitHub
parent 564a8c507f
commit bb79037ebd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -87,8 +87,14 @@ def install_dependencies():
elif gpuchoice == "b":
print("AMD GPUs are not supported. Exiting...")
sys.exit()
elif gpuchoice == "c" or gpuchoice == "d":
elif gpuchoice == "c":
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True)
elif gpuchoice == "d":
if sys.platform.startswith("linux"):
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu", assert_success=True, environment=True)
else:
run_cmd("conda install -y -k ninja git && python -m pip install torch torchvision torchaudio", assert_success=True, environment=True)
else:
print("Invalid choice. Exiting...")
sys.exit()