From b88b2b74a6a0cd01830b04a1dafe8803a2cd5bf2 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 15 Oct 2023 20:51:11 -0700 Subject: [PATCH] Experimental Intel Arc transformers support (untested) --- modules/models.py | 2 ++ modules/text_generation.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/modules/models.py b/modules/models.py index 5bd9db74..c376c808 100644 --- a/modules/models.py +++ b/modules/models.py @@ -137,6 +137,8 @@ def huggingface_loader(model_name): if torch.backends.mps.is_available(): device = torch.device('mps') model = model.to(device) + elif hasattr(torch, 'xpu') and torch.xpu.is_available(): + model = model.to('xpu') else: model = model.cuda() diff --git a/modules/text_generation.py b/modules/text_generation.py index 0f24dc58..295c7cdd 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -132,6 +132,8 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt elif torch.backends.mps.is_available(): device = torch.device('mps') return input_ids.to(device) + elif hasattr(torch, 'xpu') and torch.xpu.is_available(): + return input_ids.to('xpu') else: return input_ids.cuda()