Minor changes

This commit is contained in:
oobabooga 2023-06-20 20:23:21 -03:00 committed by GitHub
parent 5cbc0b28f2
commit c0a1baa46e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,6 @@
import argparse import argparse
import glob import glob
import os import os
import shutil
import site import site
import subprocess import subprocess
import sys import sys
@ -72,6 +71,7 @@ def install_dependencies():
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
# punctuation contains: !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ # punctuation contains: !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
from string import punctuation from string import punctuation
# Allow some characters: _-:\/.'" # Allow some characters: _-:\/.'"
special_characters = punctuation.translate({ord(char): None for char in '_-:\\/.\'"'}) special_characters = punctuation.translate({ord(char): None for char in '_-:\\/.\'"'})
if any(char in script_dir for char in special_characters): if any(char in script_dir for char in special_characters):
@ -117,13 +117,13 @@ def update_dependencies():
with open("requirements.txt") as f: with open("requirements.txt") as f:
requirements = f.read().splitlines() requirements = f.read().splitlines()
git_requirements = [req for req in requirements if req.startswith("git+")] git_requirements = [req for req in requirements if req.startswith("git+")]
# Loop through each "git+" requirement and uninstall it # Loop through each "git+" requirement and uninstall it
for req in git_requirements: for req in git_requirements:
# Extract the package name from the "git+" requirement # Extract the package name from the "git+" requirement
url = req.replace("git+", "") url = req.replace("git+", "")
package_name = url.split("/")[-1].split("@")[0] package_name = url.split("/")[-1].split("@")[0]
# Uninstall the package using pip # Uninstall the package using pip
run_cmd("python -m pip uninstall " + package_name, environment=True) run_cmd("python -m pip uninstall " + package_name, environment=True)
print(f"Uninstalled {package_name}") print(f"Uninstalled {package_name}")
@ -159,7 +159,7 @@ def update_dependencies():
# Parse output of 'pip show torch' to determine torch version # Parse output of 'pip show torch' to determine torch version
torver_cmd = run_cmd("python -m pip show torch", assert_success=True, environment=True, capture_output=True) torver_cmd = run_cmd("python -m pip show torch", assert_success=True, environment=True, capture_output=True)
torver = [v.split()[1] for v in torver_cmd.stdout.decode('utf-8').splitlines() if 'Version:' in v][0] torver = [v.split()[1] for v in torver_cmd.stdout.decode('utf-8').splitlines() if 'Version:' in v][0]
# Check for '+cu' in version string to determine if torch uses CUDA or not check for pytorch-cuda as well for backwards compatibility # Check for '+cu' in version string to determine if torch uses CUDA or not check for pytorch-cuda as well for backwards compatibility
if '+cu' not in torver and run_cmd("conda list -f pytorch-cuda | grep pytorch-cuda", environment=True, capture_output=True).returncode == 1: if '+cu' not in torver and run_cmd("conda list -f pytorch-cuda | grep pytorch-cuda", environment=True, capture_output=True).returncode == 1:
return return
@ -183,7 +183,7 @@ def update_dependencies():
os.mkdir("repositories") os.mkdir("repositories")
os.chdir("repositories") os.chdir("repositories")
# Install or update exllama as needed # Install or update exllama as needed
if not os.path.exists("exllama/"): if not os.path.exists("exllama/"):
run_cmd("git clone https://github.com/turboderp/exllama.git", environment=True) run_cmd("git clone https://github.com/turboderp/exllama.git", environment=True)
@ -191,11 +191,11 @@ def update_dependencies():
os.chdir("exllama") os.chdir("exllama")
run_cmd("git pull", environment=True) run_cmd("git pull", environment=True)
os.chdir("..") os.chdir("..")
# Fix build issue with exllama in Linux/WSL # Fix build issue with exllama in Linux/WSL
if sys.platform.startswith("linux") and not os.path.exists(f"{conda_env_path}/lib64"): if sys.platform.startswith("linux") and not os.path.exists(f"{conda_env_path}/lib64"):
run_cmd(f'ln -s "{conda_env_path}/lib" "{conda_env_path}/lib64"', environment=True) run_cmd(f'ln -s "{conda_env_path}/lib" "{conda_env_path}/lib64"', environment=True)
# Install GPTQ-for-LLaMa which enables 4bit CUDA quantization # Install GPTQ-for-LLaMa which enables 4bit CUDA quantization
if not os.path.exists("GPTQ-for-LLaMa/"): if not os.path.exists("GPTQ-for-LLaMa/"):
run_cmd("git clone https://github.com/oobabooga/GPTQ-for-LLaMa.git -b cuda", assert_success=True, environment=True) run_cmd("git clone https://github.com/oobabooga/GPTQ-for-LLaMa.git -b cuda", assert_success=True, environment=True)