Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2025-01-09 12:09:04 +01:00
Fix CUDA error on MPS backend during API request (#6572)
---------

Co-authored-by: oobabooga <oobabooga4@gmail.com>
parent 979e1f1bd6
commit 13c033c745
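Before the diff, a hedged, standalone illustration (not code from the repository) of the failure mode the commit title names: on Apple Silicon, PyTorch exposes the MPS backend but is not built with CUDA, so an unconditional `.cuda()` fallback raises while an explicit move to the MPS device succeeds.

```python
import torch

# Hedged illustration (not from the repo): on MPS-only hardware the old
# unconditional .cuda() fallback raises, while moving to the MPS device works.
x = torch.ones(3)

if torch.backends.mps.is_available() and not torch.cuda.is_available():
    try:
        x.cuda()  # roughly what the old fallback branches did
    except (AssertionError, RuntimeError) as exc:
        print(f".cuda() fails without a CUDA build: {exc}")

    x = x.to(torch.device('mps'))  # what the patched code does instead

print(x.device)
```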
modules/LoRA.py
@@ -1,11 +1,8 @@
 from pathlib import Path
 
-import torch
-from transformers import is_torch_xpu_available
-
 import modules.shared as shared
 from modules.logging_colors import logger
-from modules.models import reload_model
+from modules.models import get_device, reload_model
 
 
 def add_lora_to_model(lora_names):
@@ -132,14 +129,9 @@ def add_lora_transformers(lora_names):
     if not shared.args.load_in_8bit and not shared.args.cpu:
         shared.model.half()
         if not hasattr(shared.model, "hf_device_map"):
-            if torch.backends.mps.is_available():
-                device = torch.device('mps')
+            device = get_device()
+            if device:
                 shared.model = shared.model.to(device)
-            elif is_torch_xpu_available():
-                device = torch.device("xpu:0")
-                shared.model = shared.model.to(device)
-            else:
-                shared.model = shared.model.cuda()
 
     shared.lora_names = lora_names
 
modules/logits.py
@@ -2,11 +2,10 @@ import time
 import traceback
 
 import torch
-from transformers import is_torch_npu_available, is_torch_xpu_available
 
 from modules import models, sampler_hijack, shared
 from modules.logging_colors import logger
-from modules.models import load_model
+from modules.models import get_device, load_model
 from modules.text_generation import generate_reply
 
 global_scores = None
@@ -57,23 +56,21 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
         scores = sampler_hijack.global_scores[-1]
     else:
         if is_non_hf_exllamav2:
-            if is_torch_xpu_available():
-                tokens = shared.tokenizer.encode(prompt).to("xpu:0")
-            elif is_torch_npu_available():
-                tokens = shared.tokenizer.encode(prompt).to("npu:0")
-            else:
-                tokens = shared.tokenizer.encode(prompt).cuda()
+            device = get_device()
+            tokens = shared.tokenizer.encode(prompt)
+            if device:
+                tokens = tokens.to(device)
+
             scores = shared.model.get_logits(tokens)[-1][-1]
         elif is_non_hf_llamacpp:
             tokens = shared.tokenizer.encode(prompt)
             scores = shared.model.get_logits(tokens)[-1][-1]
         else:
-            if is_torch_xpu_available():
-                tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
-            elif is_torch_npu_available():
-                tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("npu:0")
-            else:
-                tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
+            device = get_device()
+            tokens = shared.tokenizer.encode(prompt, return_tensors='pt')
+            if device:
+                tokens = tokens.to(device)
+
             output = shared.model(input_ids=tokens)
             scores = output['logits'][-1][-1]
 
modules/models.py
@@ -21,11 +21,12 @@ from transformers import (
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     BitsAndBytesConfig,
-    GPTQConfig
+    GPTQConfig,
+    is_torch_npu_available,
+    is_torch_xpu_available
 )
 
 import modules.shared as shared
-from modules import sampler_hijack
 from modules.logging_colors import logger
 from modules.models_settings import get_model_metadata
 
@@ -56,8 +57,6 @@ if shared.args.deepspeed:
     ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
     dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
 
-sampler_hijack.hijack_samplers()
-
 
 last_generation_time = time.time()
 
@@ -172,17 +171,9 @@ def huggingface_loader(model_name):
 
         model = LoaderClass.from_pretrained(path_to_model, **params)
         if not (hasattr(model, 'is_loaded_in_4bit') and model.is_loaded_in_4bit):
-            if torch.backends.mps.is_available():
-                device = torch.device('mps')
+            device = get_device()
+            if device:
                 model = model.to(device)
-            elif is_xpu_available():
-                device = torch.device("xpu")
-                model = model.to(device)
-            elif is_npu_available():
-                device = torch.device("npu")
-                model = model.to(device)
-            else:
-                model = model.cuda()
 
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
@@ -380,13 +371,34 @@ def get_max_memory_dict():
     return max_memory if len(max_memory) > 0 else None
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return torch.device('cuda')
+    elif shared.args.deepspeed:
+        import deepspeed
+        return deepspeed.get_accelerator().current_device_name()
+    elif torch.backends.mps.is_available():
+        return torch.device('mps')
+    elif is_torch_xpu_available():
+        return torch.device('xpu:0')
+    elif is_torch_npu_available():
+        return torch.device('npu:0')
+    else:
+        return None
+
+
 def clear_torch_cache():
     gc.collect()
     if not shared.args.cpu:
-        if is_xpu_available():
-            torch.xpu.empty_cache()
-        else:
+        if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif is_xpu_available():
+            torch.xpu.empty_cache()
+        elif is_npu_available():
+            torch.npu.empty_cache()
+        elif torch.backends.mps.is_available():
+            if hasattr(torch.backends.mps, 'empty_cache'):
+                torch.backends.mps.empty_cache()
 
 
 def unload_model(keep_model_name=False):
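As a quick reference, here is a hedged standalone sketch of the selection order the new `get_device()` helper uses, together with the "move only if a device was found" pattern the callers in this commit adopt. `pick_device` and the sample tensor are illustrative names rather than repository code, and the DeepSpeed branch is omitted because it depends on `shared.args`.

```python
import torch
from transformers import is_torch_npu_available, is_torch_xpu_available


def pick_device():
    # Illustrative stand-in for modules.models.get_device(), minus the
    # DeepSpeed branch: prefer CUDA, then MPS, then XPU, then NPU, else None.
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    elif is_torch_xpu_available():
        return torch.device('xpu:0')
    elif is_torch_npu_available():
        return torch.device('npu:0')
    return None


tokens = torch.tensor([[1, 2, 3]])  # placeholder token ids
device = pick_device()
if device:  # on CPU-only machines nothing is moved
    tokens = tokens.to(device)
print(tokens.device)
```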
modules/sampler_hijack.py
@@ -5,7 +5,7 @@ import random
 
 import torch
 import transformers
-from transformers import LogitsWarper, is_torch_xpu_available
+from transformers import LogitsWarper
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,
@@ -14,6 +14,7 @@ from transformers.generation.logits_process import (
 
 from modules import shared
 from modules.logging_colors import logger
+from modules.models import get_device
 
 global_scores = None
 
@@ -339,12 +340,12 @@ class MirostatLogitsWarper(LogitsWarper):
                 break
 
         # Normalize the probabilities of the remaining words
-        if is_torch_xpu_available():
-            prob_topk = torch.softmax(sorted_logits, dim=0).to("xpu")
-            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to("xpu")
-        else:
-            prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda')
-            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')
+        prob_topk = torch.softmax(sorted_logits, dim=0)
+        prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
+        device = get_device()
+        if device:
+            prob_topk = prob_topk.to(device)
+            prev_i = prev_i.to(device)
 
         observed_surprise = -math.log2(prob_topk[prev_i])
         self.e = observed_surprise - self.mirostat_tau
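For context on the Mirostat hunk above, a hedged standalone sketch of the normalize-and-draw step the warper performs after truncating the logits. The literal values are placeholders, and the device handling mirrors the patched pattern rather than the class itself.

```python
import torch

# Placeholder truncated logits standing in for MirostatLogitsWarper's
# sorted_logits after the surprise-based truncation.
sorted_logits = torch.tensor([2.0, 1.5, 0.3])

# Renormalize the surviving candidates and draw one index from them.
prob_topk = torch.softmax(sorted_logits, dim=0)
prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)

# As in the patched code, move the tensors only when an accelerator exists.
device = torch.device('mps') if torch.backends.mps.is_available() else None
if device:
    prob_topk = prob_topk.to(device)
    prev_i = prev_i.to(device)

print(int(prev_i), prob_topk.tolist())
```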
modules/text_generation.py
@@ -16,7 +16,7 @@ from transformers import (
 )
 
 import modules.shared as shared
-from modules import models
+from modules import models, sampler_hijack
 from modules.cache_utils import process_llamacpp_cache
 from modules.callbacks import (
     Iteratorize,
@@ -28,7 +28,9 @@ from modules.grammar.grammar_utils import initialize_grammar
 from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
 from modules.html_generator import generate_basic_html
 from modules.logging_colors import logger
-from modules.models import clear_torch_cache, load_model
+from modules.models import clear_torch_cache, get_device, load_model
 
+sampler_hijack.hijack_samplers()
+
 
 def generate_reply(*args, **kwargs):
@@ -159,18 +161,12 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
 
     if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
         return input_ids
-    elif shared.args.deepspeed:
-        import deepspeed
-        return input_ids.to(deepspeed.get_accelerator().current_device_name())
-    elif torch.backends.mps.is_available():
-        device = torch.device('mps')
-        return input_ids.to(device)
-    elif is_torch_xpu_available():
-        return input_ids.to("xpu:0")
-    elif is_torch_npu_available():
-        return input_ids.to("npu:0")
     else:
-        return input_ids.cuda()
+        device = get_device()
+        if device:
+            return input_ids.to(device)
+
+        return input_ids
 
 
 def decode(output_ids, skip_special_tokens=True):
@@ -328,7 +324,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     # Encode the input
     input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
     output = input_ids[0]
-    cuda = not any((shared.args.cpu, shared.args.deepspeed))
     if state['auto_max_new_tokens']:
         generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1]
 
@@ -383,8 +378,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         if not state['stream']:
             with torch.no_grad():
                 output = shared.model.generate(**generate_params)[0]
-                if cuda:
-                    output = output.cuda()
+                device = get_device()
+                if device:
+                    output = output.to(device)
 
                 starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
                 yield get_reply_from_output_ids(output, state, starting_from=starting_from)