diff --git a/modules/models.py b/modules/models.py
index 5bd9db74..c376c808 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -137,6 +137,8 @@ def huggingface_loader(model_name):
         if torch.backends.mps.is_available():
             device = torch.device('mps')
             model = model.to(device)
+        elif hasattr(torch, 'xpu') and torch.xpu.is_available():
+            model = model.to('xpu')
         else:
             model = model.cuda()
 
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 0f24dc58..295c7cdd 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -132,6 +132,8 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     elif torch.backends.mps.is_available():
         device = torch.device('mps')
         return input_ids.to(device)
+    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
+        return input_ids.to('xpu')
     else:
         return input_ids.cuda()
 