Added two ENVs in webui.py for docker (#111)

This commit is contained in:
mongolu 2023-09-23 04:43:11 +03:00 committed by oobabooga
parent 72b4ab4c82
commit d70b8d9048

View File

@ -98,19 +98,24 @@ def run_cmd(cmd, assert_success=False, environment=False, capture_output=False,
def install_webui(): def install_webui():
print("What is your GPU")
print()
print("A) NVIDIA")
print("B) AMD (Linux/MacOS only. Requires ROCm SDK 5.4.2/5.4.3 on Linux)")
print("C) Apple M Series")
print("D) None (I want to run models in CPU mode)")
print()
# Select your GPU, or choose to run in CPU mode # Select your GPU, or choose to run in CPU mode
choice = input("Input> ").upper() if "GPU_CHOICE" in os.environ:
while choice not in ['A', 'B', 'C', 'D']: choice = os.environ["GPU_CHOICE"].upper()
print("Invalid choice. Please try again.") print_big_message(f"Selected GPU choice \"{choice}\" based on the GPU_CHOICE environment variable.")
else:
print("What is your GPU")
print()
print("A) NVIDIA")
print("B) AMD (Linux/MacOS only. Requires ROCm SDK 5.4.2/5.4.3 on Linux)")
print("C) Apple M Series")
print("D) None (I want to run models in CPU mode)")
print()
choice = input("Input> ").upper() choice = input("Input> ").upper()
while choice not in ['A', 'B', 'C', 'D']:
print("Invalid choice. Please try again.")
choice = input("Input> ").upper()
if choice == "D": if choice == "D":
print_big_message("Once the installation ends, make sure to open CMD_FLAGS.txt with\na text editor and add the --cpu flag.") print_big_message("Once the installation ends, make sure to open CMD_FLAGS.txt with\na text editor and add the --cpu flag.")
@ -261,6 +266,10 @@ if __name__ == "__main__":
install_webui() install_webui()
os.chdir(script_dir) os.chdir(script_dir)
if os.environ.get("LAUNCH_AFTER_INSTALL", "").lower() in ("no", "n", "false", "0", "f", "off"):
print_big_message("Install finished successfully and will now exit due to LAUNCH_AFTER_INSTALL.")
sys.exit()
# Check if a model has been downloaded yet # Check if a model has been downloaded yet
if len([item for item in glob.glob('models/*') if not item.endswith(('.txt', '.yaml'))]) == 0: if len([item for item in glob.glob('models/*') if not item.endswith(('.txt', '.yaml'))]) == 0:
print_big_message("WARNING: You haven't downloaded any model yet.\nOnce the web UI launches, head over to the \"Model\" tab and download one.") print_big_message("WARNING: You haven't downloaded any model yet.\nOnce the web UI launches, head over to the \"Model\" tab and download one.")