diff --git a/modules/LoRA.py b/modules/LoRA.py
index aa68ad32..283fcf4c 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -1,5 +1,7 @@
 from pathlib import Path
 
+import torch
+
 import modules.shared as shared
 from modules.models import load_model
 from modules.text_generation import clear_torch_cache
@@ -34,4 +36,8 @@ def add_lora_to_model(lora_name):
         if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()
             if not hasattr(shared.model, "hf_device_map"):
-                shared.model.cuda()
+                if torch.has_mps:
+                    device = torch.device('mps')
+                    shared.model = shared.model.to(device)
+                else:
+                    shared.model = shared.model.cuda()
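
The patch above routes the half-precision model to Apple's Metal Performance Shaders (MPS) backend on Apple Silicon instead of unconditionally calling .cuda(). For context, here is a minimal standalone sketch of the same device-selection logic, assuming PyTorch >= 1.12; note that torch.has_mps (used in the patch) still works but is deprecated in recent PyTorch releases, with torch.backends.mps.is_available() being the supported check. The helper name autodetect_device is hypothetical and not part of the patch.

    import torch

    def autodetect_device() -> torch.device:
        # Prefer MPS on Apple Silicon, then CUDA, then fall back to CPU.
        # torch.backends.mps.is_available() is the supported check on
        # PyTorch >= 1.12 (an assumption here); the patch itself uses the
        # older, now-deprecated torch.has_mps attribute.
        if torch.backends.mps.is_available():
            return torch.device('mps')
        if torch.cuda.is_available():
            return torch.device('cuda')
        return torch.device('cpu')

    # Hypothetical usage mirroring the patched branch:
    # shared.model = shared.model.to(autodetect_device())

Reassigning the result of .to(device) back to shared.model, as the patch does, is the safe idiom: for nn.Module the move is in place, but reassignment also covers objects whose .to() returns a new instance.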