diff --git a/server.py b/server.py index 1407b4c7..2b49c572 100644 --- a/server.py +++ b/server.py @@ -77,7 +77,7 @@ def load_model(model_name): t0 = time.time() # Default settings - if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None): + if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None): if Path(f"torch-dumps/{model_name}.pt").exists(): print("Loading in .pt format...") model = torch.load(Path(f"torch-dumps/{model_name}.pt"))