From 89e0d15cf5c4b06967a31a51e40fd194a2bfc71a Mon Sep 17 00:00:00 2001 From: appe233 <89209249+appe233@users.noreply.github.com> Date: Tue, 18 Jul 2023 08:27:18 +0800 Subject: [PATCH] Use 'torch.backends.mps.is_available' to check if mps is supported (#3164) --- modules/LoRA.py | 2 +- modules/models.py | 4 ++-- modules/text_generation.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/LoRA.py b/modules/LoRA.py index 0626c969..1350783f 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -132,7 +132,7 @@ def add_lora_transformers(lora_names): if not shared.args.load_in_8bit and not shared.args.cpu: shared.model.half() if not hasattr(shared.model, "hf_device_map"): - if torch.has_mps: + if torch.backends.mps.is_available(): device = torch.device('mps') shared.model = shared.model.to(device) else: diff --git a/modules/models.py b/modules/models.py index 9d9ba951..232d5fa6 100644 --- a/modules/models.py +++ b/modules/models.py @@ -147,7 +147,7 @@ def huggingface_loader(model_name): # Load the model in simple 16-bit mode by default if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]): model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code) - if torch.has_mps: + if torch.backends.mps.is_available(): device = torch.device('mps') model = model.to(device) else: @@ -167,7 +167,7 @@ def huggingface_loader(model_name): "trust_remote_code": shared.args.trust_remote_code } - if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)): + if not any((shared.args.cpu, torch.cuda.is_available(), torch.backends.mps.is_available())): logger.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.") shared.args.cpu = True diff --git a/modules/text_generation.py b/modules/text_generation.py index 566c2f55..d3939d3f 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -57,7 +57,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt return input_ids.numpy() elif shared.args.deepspeed: return input_ids.to(device=local_rank) - elif torch.has_mps: + elif torch.backends.mps.is_available(): device = torch.device('mps') return input_ids.to(device) else: