mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-22 08:07:56 +01:00)

Intel Gpu support initialization (#4340)

parent 317e2c857e
commit 778a010df8
@@ -3,6 +3,7 @@ from typing import List, Optional
 
 import torch
 from PIL import Image
+from transformers import is_torch_xpu_available
 
 
 class AbstractMultimodalPipeline(ABC):

@@ -55,7 +56,7 @@ class AbstractMultimodalPipeline(ABC):
     def _get_device(self, setting_name: str, params: dict):
         if params[setting_name] is None:
-            return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            return torch.device("cuda:0" if torch.cuda.is_available() else "xpu:0" if is_torch_xpu_available() else "cpu")
 
         return torch.device(params[setting_name])
 
     def _get_dtype(self, setting_name: str, params: dict):

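The pattern above (prefer CUDA, then fall back to Intel XPU, then CPU) repeats throughout this commit. Below is a standalone sketch of that resolution order, assuming a transformers version that exports is_torch_xpu_available and an XPU-enabled PyTorch build; resolve_device is an illustrative name, not a helper from the repository.

    import torch
    from transformers import is_torch_xpu_available

    def resolve_device() -> torch.device:
        # Prefer NVIDIA CUDA, then Intel XPU, then plain CPU.
        if torch.cuda.is_available():
            return torch.device("cuda:0")
        if is_torch_xpu_available():
            return torch.device("xpu:0")
        return torch.device("cpu")

    print(resolve_device())
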
@@ -1,5 +1,6 @@
 from pathlib import Path
 
+from accelerate import is_xpu_available
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 
 import modules.shared as shared

@@ -41,7 +42,7 @@ def load_quantized(model_name):
     # Define the params for AutoGPTQForCausalLM.from_quantized
     params = {
         'model_basename': pt_path.stem,
-        'device': "cuda:0" if not shared.args.cpu else "cpu",
+        'device': "xpu:0" if is_xpu_available() else "cuda:0" if not shared.args.cpu else "cpu",
         'use_triton': shared.args.triton,
         'inject_fused_attention': not shared.args.no_inject_fused_attention,
         'inject_fused_mlp': not shared.args.no_inject_fused_mlp,

@@ -5,15 +5,15 @@ from pathlib import Path
 import accelerate
 import torch
 import transformers
+from accelerate import is_xpu_available
+from gptq_for_llama import llama_inference_offload
+from gptq_for_llama.modelutils import find_layers
+from gptq_for_llama.quant import make_quant
 from transformers import AutoConfig, AutoModelForCausalLM
 
 import modules.shared as shared
 from modules.logging_colors import logger
 
-from gptq_for_llama import llama_inference_offload
-from gptq_for_llama.modelutils import find_layers
-from gptq_for_llama.quant import make_quant
-
 
 # This function is a replacement for the load_quant function in the
 # GPTQ-for_LLaMa repository. It supports more models and branches.

@@ -144,7 +144,7 @@ def load_quantized(model_name):
         model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
 
     # accelerate offload (doesn't work properly)
-    if shared.args.gpu_memory or torch.cuda.device_count() > 1:
+    if shared.args.gpu_memory or torch.cuda.device_count() > 1 or (is_xpu_available() and torch.xpu.device_count() > 1):
         if shared.args.gpu_memory:
             memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
             max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'

@@ -163,6 +163,9 @@ def load_quantized(model_name):
 
     # No offload
     elif not shared.args.cpu:
-        model = model.to(torch.device('cuda:0'))
+        if is_xpu_available():
+            model = model.to(torch.device("xpu:0"))
+        else:
+            model = model.to(torch.device('cuda:0'))
 
     return model

@@ -2,6 +2,7 @@ from pathlib import Path
 
 import torch
 from peft import PeftModel
+from transformers import is_torch_xpu_available
 
 import modules.shared as shared
 from modules.logging_colors import logger

@@ -179,6 +180,9 @@ def add_lora_transformers(lora_names):
         if torch.backends.mps.is_available():
             device = torch.device('mps')
             shared.model = shared.model.to(device)
+        elif is_torch_xpu_available():
+            device = torch.device("xpu:0")
+            shared.model = shared.model.to(device)
         else:
             shared.model = shared.model.cuda()
 

@@ -9,6 +9,7 @@ from pathlib import Path
 
 import numpy as np
 from tokenizers import Tokenizer
+from transformers import is_torch_xpu_available
 
 import modules.shared as shared
 from modules.callbacks import Iteratorize

@@ -27,7 +28,7 @@ class RWKVModel:
         pass
 
     @classmethod
-    def from_pretrained(self, path, dtype="fp16", device="cuda"):
+    def from_pretrained(self, path, dtype="bf16" if is_torch_xpu_available() else "fp16", device="xpu" if is_torch_xpu_available() else "cuda"):
         tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
         if shared.args.rwkv_strategy is None:
             model = RWKV(model=str(path), strategy=f'{device} {dtype}')

@@ -5,6 +5,7 @@ from threading import Thread
 
 import torch
 import transformers
+from transformers import is_torch_xpu_available
 
 import modules.shared as shared
 

@@ -92,4 +93,7 @@ class Iteratorize:
 def clear_torch_cache():
     gc.collect()
     if not shared.args.cpu:
-        torch.cuda.empty_cache()
+        if is_torch_xpu_available():
+            torch.xpu.empty_cache()
+        else:
+            torch.cuda.empty_cache()

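The same guard reappears in the other cache-clearing helper later in the commit. A minimal, self-contained version of that logic, assuming the transformers helper is available; the function name here is illustrative rather than taken from the repository.

    import gc

    import torch
    from transformers import is_torch_xpu_available

    def clear_accelerator_cache(cpu_only: bool = False) -> None:
        # Release cached allocator memory on whichever accelerator is in use.
        gc.collect()
        if cpu_only:
            return
        if is_torch_xpu_available():
            torch.xpu.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()
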
@@ -1,4 +1,5 @@
 import torch
+from transformers import is_torch_xpu_available
 
 from modules import sampler_hijack, shared
 from modules.logging_colors import logger

@@ -32,11 +33,17 @@ def get_next_logits(prompt, state, use_samplers, previous):
         scores = sampler_hijack.global_scores[-1]
     else:
         if is_non_hf_exllamav2 or is_non_hf_exllamav1:
-            tokens = shared.tokenizer.encode(prompt).cuda()
+            if is_torch_xpu_available():
+                tokens = shared.tokenizer.encode(prompt).to("xpu:0")
+            else:
+                tokens = shared.tokenizer.encode(prompt).cuda()
             scores = shared.model.get_logits(tokens)[-1][-1]
         elif is_non_hf_llamacpp:
             tokens = shared.tokenizer.encode(prompt)
             scores = shared.model.get_logits(tokens)[-1][-1]
         else:
-            tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
+            if is_torch_xpu_available():
+                tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
+            else:
+                tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
             output = shared.model(input_ids=tokens)

@@ -7,7 +7,12 @@ from pathlib import Path
 
 import torch
 import transformers
-from accelerate import infer_auto_device_map, init_empty_weights
+from accelerate import (
+    infer_auto_device_map,
+    init_empty_weights,
+    is_ccl_available,
+    is_xpu_available
+)
 from transformers import (
     AutoConfig,
     AutoModel,

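Note that the commit relies on two different detection helpers: accelerate's is_xpu_available in the loader modules and transformers' is_torch_xpu_available elsewhere. A guarded probe like the following (the name xpu_detected is illustrative) shows whether the current environment exposes either one:

    def xpu_detected() -> bool:
        # Check both helpers this commit relies on; either one reporting True
        # is enough to treat an Intel XPU as present.
        try:
            from accelerate import is_xpu_available
            if is_xpu_available():
                return True
        except ImportError:
            pass
        try:
            from transformers import is_torch_xpu_available
            return is_torch_xpu_available()
        except ImportError:
            return False
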
@@ -38,6 +43,10 @@ if shared.args.deepspeed:
     # Distributed setup
     local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))
-    torch.cuda.set_device(local_rank)
-    deepspeed.init_distributed()
+    if is_xpu_available() and is_ccl_available():
+        torch.xpu.set_device(local_rank)
+        deepspeed.init_distributed(backend="ccl")
+    else:
+        torch.cuda.set_device(local_rank)
+        deepspeed.init_distributed()
     ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)

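Distilled from the hunk above: the distributed setup amounts to choosing the oneCCL backend when both an XPU and CCL bindings are present, and falling through to the default (NCCL) path otherwise. This sketch assumes deepspeed and a recent accelerate are installed and only restates the change in a self-contained form.

    import os

    import deepspeed
    import torch
    from accelerate import is_ccl_available, is_xpu_available

    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if is_xpu_available() and is_ccl_available():
        # Intel GPUs communicate over oneCCL.
        torch.xpu.set_device(local_rank)
        deepspeed.init_distributed(backend="ccl")
    else:
        # Default path: CUDA device + DeepSpeed's default backend.
        torch.cuda.set_device(local_rank)
        deepspeed.init_distributed()
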
@@ -137,8 +146,9 @@ def huggingface_loader(model_name):
         if torch.backends.mps.is_available():
             device = torch.device('mps')
             model = model.to(device)
-        elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-            model = model.to('xpu')
+        elif is_xpu_available():
+            device = torch.device("xpu")
+            model = model.to(device)
         else:
             model = model.cuda()
 

@@ -151,15 +161,10 @@ def huggingface_loader(model_name):
 
     # Load with quantization and/or offloading
     else:
-        conditions = [
-            shared.args.cpu,
-            torch.cuda.is_available(),
-            torch.backends.mps.is_available(),
-            hasattr(torch, 'xpu') and torch.xpu.is_available(),
-        ]
-
-        if not any(conditions):
-            logger.warning('No GPU has been detected by Pytorch. Falling back to CPU mode.')
+        if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
+            logger.warning('torch.cuda.is_available() and is_xpu_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
             shared.args.cpu = True
 
         if shared.args.cpu:

@@ -362,7 +367,12 @@ def RWKV_loader(model_name):
     '''
     from modules.RWKV import RWKVModel, RWKVTokenizer
 
-    model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+    model = RWKVModel.from_pretrained(
+        Path(f'{shared.args.model_dir}/{model_name}'),
+        dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16",
+        device="cpu" if shared.args.cpu else "xpu" if is_xpu_available() else "cuda"
+    )
+
     tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
     return model, tokenizer
 

@@ -380,6 +390,9 @@ def get_max_memory_dict():
     # If --auto-devices is provided standalone, try to get a reasonable value
     # for the maximum memory of device :0
     elif shared.args.auto_devices:
-        total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
+        if is_xpu_available():
+            total_mem = (torch.xpu.get_device_properties(0).total_memory / (1024 * 1024))
+        else:
+            total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
         suggestion = round((total_mem - 1000) / 1000) * 1000
         if total_mem - suggestion < 800:

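The VRAM probe above generalizes to a small helper. This is a hedged sketch (suggest_max_memory_mib is not a repository function) that stops where the hunk is truncated rather than guessing the rest of get_max_memory_dict.

    import torch
    from transformers import is_torch_xpu_available

    def suggest_max_memory_mib() -> int:
        # Total memory of device 0 in MiB on whichever backend is present.
        if torch.cuda.is_available():
            total = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024)
        elif is_torch_xpu_available():
            total = torch.xpu.get_device_properties(0).total_memory / (1024 * 1024)
        else:
            return 0
        # Same rounding as above: leave ~1 GiB headroom, snap to 1000 MiB steps.
        return int(round((total - 1000) / 1000) * 1000)
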
@@ -395,6 +408,9 @@ def get_max_memory_dict():
 def clear_torch_cache():
     gc.collect()
     if not shared.args.cpu:
-        torch.cuda.empty_cache()
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+        else:
+            torch.cuda.empty_cache()
 
 

@@ -2,7 +2,7 @@ import math
 
 import torch
 import transformers
-from transformers import LogitsWarper
+from transformers import LogitsWarper, is_torch_xpu_available
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,

@@ -106,8 +106,11 @@ class MirostatLogitsWarper(LogitsWarper):
                 break
 
         # Normalize the probabilities of the remaining words
-        prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda')
-
-        prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')
+        if is_torch_xpu_available():
+            prob_topk = torch.softmax(sorted_logits, dim=0).to("xpu")
+            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to("xpu")
+        else:
+            prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda')
+            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')
 
         observed_surprise = -math.log2(prob_topk[prev_i])

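A design note on the Mirostat change: instead of branching on the backend, the same step can stay device-agnostic by sampling on whatever device sorted_logits already occupies. This is an alternative sketch, not what the commit does.

    import torch

    def sample_from_sorted_logits(sorted_logits: torch.Tensor) -> torch.Tensor:
        # softmax and multinomial inherit the device of their input tensor,
        # so no explicit .to('cuda') / .to('xpu') is needed.
        prob_topk = torch.softmax(sorted_logits, dim=0)
        return torch.multinomial(prob_topk, num_samples=1, replacement=True)
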
@@ -9,7 +9,7 @@ import traceback
 import numpy as np
 import torch
 import transformers
-from transformers import LogitsProcessorList
+from transformers import LogitsProcessorList, is_torch_xpu_available
 
 import modules.shared as shared
 from modules.callbacks import (

@@ -132,8 +132,8 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     elif torch.backends.mps.is_available():
         device = torch.device('mps')
         return input_ids.to(device)
-    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-        return input_ids.to('xpu')
+    elif is_torch_xpu_available():
+        return input_ids.to("xpu:0")
     else:
         return input_ids.cuda()
 

@@ -238,7 +238,8 @@ def set_manual_seed(seed):
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
+    elif is_torch_xpu_available():
+        torch.xpu.manual_seed_all(seed)
     return seed
 
 

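Seeding follows the same branch shape. A compact standalone sketch, assuming the transformers helper; the "-1 means pick a random seed" convention here is an assumption about the surrounding webui code, not something shown in this hunk, and seed_everything is an illustrative name.

    import random

    import torch
    from transformers import is_torch_xpu_available

    def seed_everything(seed: int) -> int:
        if seed == -1:  # assumed convention: -1 picks a fresh random seed
            seed = random.randint(1, 2**31)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        elif is_torch_xpu_available():
            torch.xpu.manual_seed_all(seed)
        return seed
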
@@ -26,6 +26,7 @@ from peft import (
 )
 from peft.utils.other import \
     TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules
+from transformers import is_torch_xpu_available
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 )

@@ -626,6 +627,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
             # TODO: Enable multi-device support
             ddp_find_unused_parameters=None,
             no_cuda=shared.args.cpu,
+            use_ipex=True if is_torch_xpu_available() and not shared.args.cpu else False
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
         callbacks=list([Callbacks()])

@@ -4,10 +4,10 @@ from pathlib import Path
 import gradio as gr
 import torch
 import yaml
+from transformers import is_torch_xpu_available
 
 from modules import shared
-
 
 with open(Path(__file__).resolve().parent / '../css/NotoSans/stylesheet.css', 'r') as f:
     css = f.read()
 with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:

@@ -85,7 +85,10 @@ def list_model_elements():
         'rope_freq_base',
         'numa',
     ]
-    for i in range(torch.cuda.device_count()):
-        elements.append(f'gpu_memory_{i}')
+    if is_torch_xpu_available():
+        for i in range(torch.xpu.device_count()):
+            elements.append(f'gpu_memory_{i}')
+    else:
+        for i in range(torch.cuda.device_count()):
+            elements.append(f'gpu_memory_{i}')
 

@@ -8,6 +8,7 @@ from pathlib import Path
 import gradio as gr
 import psutil
 import torch
+from transformers import is_torch_xpu_available
 
 from modules import loaders, shared, ui, utils
 from modules.logging_colors import logger

@@ -27,6 +28,10 @@ def create_ui():
 
     # Finding the default values for the GPU and CPU memories
     total_mem = []
-    for i in range(torch.cuda.device_count()):
-        total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
+    if is_torch_xpu_available():
+        for i in range(torch.xpu.device_count()):
+            total_mem.append(math.floor(torch.xpu.get_device_properties(i).total_memory / (1024 * 1024)))
+    else:
+        for i in range(torch.cuda.device_count()):
+            total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
 

one_click.py (+13 lines)
@@ -56,6 +56,19 @@ def cpu_has_avx2():
         return True
 
 
+def cpu_has_amx():
+    try:
+        import cpuinfo
+
+        info = cpuinfo.get_cpu_info()
+        if 'amx' in info['flags']:
+            return True
+        else:
+            return False
+    except:
+        return True
+
+
 def torch_version():
     site_packages_path = None
     for sitedir in site.getsitepackages():

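For context on the new cpu_has_amx() helper: py-cpuinfo reports capabilities as a list of flag strings, and exact AMX flag names vary by platform (often 'amx_tile', 'amx_bf16', 'amx_int8'), so an exploratory check of what your machine actually reports looks like this (requires py-cpuinfo; purely illustrative):

    import cpuinfo  # pip install py-cpuinfo

    flags = cpuinfo.get_cpu_info().get('flags', [])
    print('avx2 supported:', 'avx2' in flags)
    print('amx-related flags:', [f for f in flags if f.startswith('amx')])
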