From 13c033c745b54ce3b35c805ca0ff12f8f93bfb5d Mon Sep 17 00:00:00 2001
From: Petr Korolev
Date: Thu, 2 Jan 2025 06:06:11 +0300
Subject: [PATCH] Fix CUDA error on MPS backend during API request (#6572)

---------

Co-authored-by: oobabooga
---
 modules/LoRA.py            | 14 +++---------
 modules/logits.py          | 25 +++++++++------------
 modules/models.py          | 46 ++++++++++++++++++++++++--------------
 modules/sampler_hijack.py  | 15 +++++++------
 modules/text_generation.py | 28 ++++++++++-------------
 5 files changed, 63 insertions(+), 65 deletions(-)

diff --git a/modules/LoRA.py b/modules/LoRA.py
index 4fd144ba..e1ad01d7 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -1,11 +1,8 @@
 from pathlib import Path
 
-import torch
-from transformers import is_torch_xpu_available
-
 import modules.shared as shared
 from modules.logging_colors import logger
-from modules.models import reload_model
+from modules.models import get_device, reload_model
 
 
 def add_lora_to_model(lora_names):
@@ -132,14 +129,9 @@ def add_lora_transformers(lora_names):
         if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()
             if not hasattr(shared.model, "hf_device_map"):
-                if torch.backends.mps.is_available():
-                    device = torch.device('mps')
+                device = get_device()
+                if device:
                     shared.model = shared.model.to(device)
-                elif is_torch_xpu_available():
-                    device = torch.device("xpu:0")
-                    shared.model = shared.model.to(device)
-                else:
-                    shared.model = shared.model.cuda()
 
     shared.lora_names = lora_names
 
diff --git a/modules/logits.py b/modules/logits.py
index 73cabb41..f8a1e80c 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -2,11 +2,10 @@ import time
 import traceback
 
 import torch
-from transformers import is_torch_npu_available, is_torch_xpu_available
 
 from modules import models, sampler_hijack, shared
 from modules.logging_colors import logger
-from modules.models import load_model
+from modules.models import get_device, load_model
 from modules.text_generation import generate_reply
 
 global_scores = None
@@ -57,23 +56,21 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
         scores = sampler_hijack.global_scores[-1]
     else:
         if is_non_hf_exllamav2:
-            if is_torch_xpu_available():
-                tokens = shared.tokenizer.encode(prompt).to("xpu:0")
-            elif is_torch_npu_available():
-                tokens = shared.tokenizer.encode(prompt).to("npu:0")
-            else:
-                tokens = shared.tokenizer.encode(prompt).cuda()
+            device = get_device()
+            tokens = shared.tokenizer.encode(prompt)
+            if device:
+                tokens = tokens.to(device)
+
             scores = shared.model.get_logits(tokens)[-1][-1]
         elif is_non_hf_llamacpp:
             tokens = shared.tokenizer.encode(prompt)
             scores = shared.model.get_logits(tokens)[-1][-1]
         else:
-            if is_torch_xpu_available():
-                tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
-            elif is_torch_npu_available():
-                tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("npu:0")
-            else:
-                tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
+            device = get_device()
+            tokens = shared.tokenizer.encode(prompt, return_tensors='pt')
+            if device:
+                tokens = tokens.to(device)
+
             output = shared.model(input_ids=tokens)
             scores = output['logits'][-1][-1]
 
diff --git a/modules/models.py b/modules/models.py
index 7a52c07c..d906535b 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -21,11 +21,12 @@ from transformers import (
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     BitsAndBytesConfig,
-    GPTQConfig
+    GPTQConfig,
+    is_torch_npu_available,
+    is_torch_xpu_available
 )
 
 import modules.shared as shared
-from modules import sampler_hijack
 from modules.logging_colors import logger
 from modules.models_settings import get_model_metadata
 
@@ -56,8 +57,6 @@ if shared.args.deepspeed:
     ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
     dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
 
-sampler_hijack.hijack_samplers()
-
 last_generation_time = time.time()
 
 
@@ -172,17 +171,9 @@ def huggingface_loader(model_name):
 
         model = LoaderClass.from_pretrained(path_to_model, **params)
         if not (hasattr(model, 'is_loaded_in_4bit') and model.is_loaded_in_4bit):
-            if torch.backends.mps.is_available():
-                device = torch.device('mps')
+            device = get_device()
+            if device:
                 model = model.to(device)
-            elif is_xpu_available():
-                device = torch.device("xpu")
-                model = model.to(device)
-            elif is_npu_available():
-                device = torch.device("npu")
-                model = model.to(device)
-            else:
-                model = model.cuda()
 
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
@@ -380,13 +371,34 @@ def get_max_memory_dict():
     return max_memory if len(max_memory) > 0 else None
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return torch.device('cuda')
+    elif shared.args.deepspeed:
+        import deepspeed
+        return deepspeed.get_accelerator().current_device_name()
+    elif torch.backends.mps.is_available():
+        return torch.device('mps')
+    elif is_torch_xpu_available():
+        return torch.device('xpu:0')
+    elif is_torch_npu_available():
+        return torch.device('npu:0')
+    else:
+        return None
+
+
 def clear_torch_cache():
     gc.collect()
     if not shared.args.cpu:
-        if is_xpu_available():
-            torch.xpu.empty_cache()
-        else:
+        if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif is_xpu_available():
+            torch.xpu.empty_cache()
+        elif is_npu_available():
+            torch.npu.empty_cache()
+        elif torch.backends.mps.is_available():
+            if hasattr(torch.backends.mps, 'empty_cache'):
+                torch.backends.mps.empty_cache()
 
 
 def unload_model(keep_model_name=False):
diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
index 24dbcf2e..62ceca8d 100644
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -5,7 +5,7 @@ import random
 
 import torch
 import transformers
-from transformers import LogitsWarper, is_torch_xpu_available
+from transformers import LogitsWarper
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,
@@ -14,6 +14,7 @@ from transformers.generation.logits_process import (
 
 from modules import shared
 from modules.logging_colors import logger
+from modules.models import get_device
 
 global_scores = None
 
@@ -339,12 +340,12 @@ class MirostatLogitsWarper(LogitsWarper):
                 break
 
         # Normalize the probabilities of the remaining words
-        if is_torch_xpu_available():
-            prob_topk = torch.softmax(sorted_logits, dim=0).to("xpu")
-            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to("xpu")
-        else:
-            prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda')
-            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')
+        prob_topk = torch.softmax(sorted_logits, dim=0)
+        prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
+        device = get_device()
+        if device:
+            prob_topk = prob_topk.to(device)
+            prev_i = prev_i.to(device)
 
         observed_surprise = -math.log2(prob_topk[prev_i])
         self.e = observed_surprise - self.mirostat_tau
diff --git a/modules/text_generation.py b/modules/text_generation.py
index c999fa81..db415dce 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -16,7 +16,7 @@ from transformers import (
 )
 
 import modules.shared as shared
-from modules import models
+from modules import models, sampler_hijack
 from modules.cache_utils import process_llamacpp_cache
 from modules.callbacks import (
     Iteratorize,
@@ -28,7 +28,9 @@ from modules.grammar.grammar_utils import initialize_grammar
 from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
 from modules.html_generator import generate_basic_html
 from modules.logging_colors import logger
-from modules.models import clear_torch_cache, load_model
+from modules.models import clear_torch_cache, get_device, load_model
+
+sampler_hijack.hijack_samplers()
 
 
 def generate_reply(*args, **kwargs):
@@ -159,18 +161,12 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
 
     if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
         return input_ids
-    elif shared.args.deepspeed:
-        import deepspeed
-        return input_ids.to(deepspeed.get_accelerator().current_device_name())
-    elif torch.backends.mps.is_available():
-        device = torch.device('mps')
-        return input_ids.to(device)
-    elif is_torch_xpu_available():
-        return input_ids.to("xpu:0")
-    elif is_torch_npu_available():
-        return input_ids.to("npu:0")
     else:
-        return input_ids.cuda()
+        device = get_device()
+        if device:
+            return input_ids.to(device)
+
+        return input_ids
 
 
 def decode(output_ids, skip_special_tokens=True):
@@ -328,7 +324,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     # Encode the input
     input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
     output = input_ids[0]
-    cuda = not any((shared.args.cpu, shared.args.deepspeed))
     if state['auto_max_new_tokens']:
         generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1]
 
@@ -383,8 +378,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         if not state['stream']:
             with torch.no_grad():
                 output = shared.model.generate(**generate_params)[0]
-                if cuda:
-                    output = output.cuda()
+                device = get_device()
+                if device:
+                    output = output.to(device)
 
             starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
             yield get_reply_from_output_ids(output, state, starting_from=starting_from)
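
The patch replaces scattered per-backend branches (cuda, xpu, npu, mps, DeepSpeed) with one helper, modules.models.get_device(), which returns a device when an accelerator is present and None otherwise, so call sites move tensors only when there is somewhere to move them instead of defaulting to .cuda(). Below is a minimal standalone sketch of that pattern for reference; pick_device() is a hypothetical stand-in for the new helper and omits the DeepSpeed and shared.args checks the real function performs.

import torch
from transformers import is_torch_npu_available, is_torch_xpu_available


def pick_device():
    # Hypothetical stand-in for modules.models.get_device(); the real helper
    # also consults shared.args.deepspeed and DeepSpeed's accelerator API.
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    elif is_torch_xpu_available():
        return torch.device('xpu:0')
    elif is_torch_npu_available():
        return torch.device('npu:0')
    else:
        return None  # CPU only: leave tensors where they are


# Usage mirrors the call sites above: move the tensor only if a device exists.
tokens = torch.tensor([[1, 2, 3]])
device = pick_device()
if device:
    tokens = tokens.to(device)
print(tokens.device)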