Add Ascend NPU support (basic) (#5541)

This commit is contained in:
wangshuai09 2024-04-12 05:42:20 +08:00 committed by GitHub
parent a90509d82e
commit fd4e46bce2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 35 additions and 7 deletions

View File

@ -5,7 +5,7 @@ from threading import Thread
import torch
import transformers
from transformers import is_torch_xpu_available
from transformers import is_torch_npu_available, is_torch_xpu_available
import modules.shared as shared
@ -99,5 +99,7 @@ def clear_torch_cache():
if not shared.args.cpu:
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
else:
torch.cuda.empty_cache()

View File

@ -1,5 +1,5 @@
import torch
from transformers import is_torch_xpu_available
from transformers import is_torch_npu_available, is_torch_xpu_available
from modules import sampler_hijack, shared
from modules.logging_colors import logger
@ -34,6 +34,8 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
if is_non_hf_exllamav2:
if is_torch_xpu_available():
tokens = shared.tokenizer.encode(prompt).to("xpu:0")
elif is_torch_npu_available():
tokens = shared.tokenizer.encode(prompt).to("npu:0")
else:
tokens = shared.tokenizer.encode(prompt).cuda()
scores = shared.model.get_logits(tokens)[-1][-1]
@ -43,6 +45,8 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
else:
if is_torch_xpu_available():
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
elif is_torch_npu_available():
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("npu:0")
else:
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
output = shared.model(input_ids=tokens)

View File

@ -10,7 +10,11 @@ from pathlib import Path
import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import is_ccl_available, is_xpu_available
from accelerate.utils import (
is_ccl_available,
is_npu_available,
is_xpu_available
)
from transformers import (
AutoConfig,
AutoModel,
@ -45,6 +49,9 @@ if shared.args.deepspeed:
if is_xpu_available() and is_ccl_available():
torch.xpu.set_device(local_rank)
deepspeed.init_distributed(backend="ccl")
elif is_npu_available():
torch.npu.set_device(local_rank)
deepspeed.init_distributed(dist_backend="hccl")
else:
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()
@ -164,6 +171,9 @@ def huggingface_loader(model_name):
elif is_xpu_available():
device = torch.device("xpu")
model = model.to(device)
elif is_npu_available():
device = torch.device("npu")
model = model.to(device)
else:
model = model.cuda()

View File

@ -10,7 +10,11 @@ import traceback
import numpy as np
import torch
import transformers
from transformers import LogitsProcessorList, is_torch_xpu_available
from transformers import (
LogitsProcessorList,
is_torch_npu_available,
is_torch_xpu_available
)
import modules.shared as shared
from modules.cache_utils import process_llamacpp_cache
@ -24,7 +28,7 @@ from modules.grammar.grammar_utils import initialize_grammar
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
from modules.html_generator import generate_basic_html
from modules.logging_colors import logger
from modules.models import clear_torch_cache, local_rank
from modules.models import clear_torch_cache
def generate_reply(*args, **kwargs):
@ -131,12 +135,15 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model'] or shared.args.cpu:
return input_ids
elif shared.args.deepspeed:
return input_ids.to(device=local_rank)
import deepspeed
return input_ids.to(deepspeed.get_accelerator().current_device_name())
elif torch.backends.mps.is_available():
device = torch.device('mps')
return input_ids.to(device)
elif is_torch_xpu_available():
return input_ids.to("xpu:0")
elif is_torch_npu_available():
return input_ids.to("npu:0")
else:
return input_ids.cuda()
@ -213,6 +220,8 @@ def set_manual_seed(seed):
torch.cuda.manual_seed_all(seed)
elif is_torch_xpu_available():
torch.xpu.manual_seed_all(seed)
elif is_torch_npu_available():
torch.npu.manual_seed_all(seed)
return seed

View File

@ -8,7 +8,7 @@ from pathlib import Path
import gradio as gr
import psutil
import torch
from transformers import is_torch_xpu_available
from transformers import is_torch_npu_available, is_torch_xpu_available
from modules import loaders, shared, ui, utils
from modules.logging_colors import logger
@ -32,6 +32,9 @@ def create_ui():
if is_torch_xpu_available():
for i in range(torch.xpu.device_count()):
total_mem.append(math.floor(torch.xpu.get_device_properties(i).total_memory / (1024 * 1024)))
elif is_torch_npu_available():
for i in range(torch.npu.device_count()):
total_mem.append(math.floor(torch.npu.get_device_properties(i).total_memory / (1024 * 1024)))
else:
for i in range(torch.cuda.device_count()):
total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))