mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Make llama-cpp-python not crash immediately
This commit is contained in:
parent
f77cf159ba
commit
f243b4ca9c
@ -1,3 +1,4 @@
|
|||||||
|
import importlib
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -5,20 +6,38 @@ from tqdm import tqdm
|
|||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.cache_utils import process_llamacpp_cache
|
from modules.cache_utils import process_llamacpp_cache
|
||||||
|
|
||||||
try:
|
|
||||||
import llama_cpp
|
|
||||||
except:
|
|
||||||
llama_cpp = None
|
|
||||||
|
|
||||||
try:
|
def llama_cpp_lib():
|
||||||
import llama_cpp_cuda
|
return_lib = None
|
||||||
except:
|
|
||||||
llama_cpp_cuda = None
|
|
||||||
|
|
||||||
try:
|
if shared.args.cpu:
|
||||||
import llama_cpp_cuda_tensorcores
|
try:
|
||||||
except:
|
return_lib = importlib.import_module('llama_cpp')
|
||||||
llama_cpp_cuda_tensorcores = None
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if shared.args.tensorcores and return_lib is None:
|
||||||
|
try:
|
||||||
|
return_lib = importlib.import_module('llama_cpp_cuda_tensorcores')
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if return_lib is None:
|
||||||
|
try:
|
||||||
|
return_lib = importlib.import_module('llama_cpp_cuda')
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if return_lib is None and not shared.args.cpu:
|
||||||
|
try:
|
||||||
|
return_lib = importlib.import_module('llama_cpp')
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if return_lib is not None:
|
||||||
|
monkey_patch_llama_cpp_python(return_lib)
|
||||||
|
|
||||||
|
return return_lib
|
||||||
|
|
||||||
|
|
||||||
def eval_with_progress(self, tokens: Sequence[int]):
|
def eval_with_progress(self, tokens: Sequence[int]):
|
||||||
@ -63,7 +82,7 @@ def eval_with_progress(self, tokens: Sequence[int]):
|
|||||||
self.n_tokens += n_tokens
|
self.n_tokens += n_tokens
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_generate(lib):
|
def monkey_patch_llama_cpp_python(lib):
|
||||||
|
|
||||||
def my_generate(self, *args, **kwargs):
|
def my_generate(self, *args, **kwargs):
|
||||||
|
|
||||||
@ -77,11 +96,6 @@ def monkey_patch_generate(lib):
|
|||||||
for output in self.original_generate(*args, **kwargs):
|
for output in self.original_generate(*args, **kwargs):
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
|
lib.Llama.eval = eval_with_progress
|
||||||
lib.Llama.original_generate = lib.Llama.generate
|
lib.Llama.original_generate = lib.Llama.generate
|
||||||
lib.Llama.generate = my_generate
|
lib.Llama.generate = my_generate
|
||||||
|
|
||||||
|
|
||||||
for lib in [llama_cpp, llama_cpp_cuda, llama_cpp_cuda_tensorcores]:
|
|
||||||
if lib is not None:
|
|
||||||
lib.Llama.eval = eval_with_progress
|
|
||||||
monkey_patch_generate(lib)
|
|
||||||
|
@ -7,35 +7,10 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
from modules import llama_cpp_python_hijack, shared
|
from modules import shared
|
||||||
|
from modules.llama_cpp_python_hijack import llama_cpp_lib
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
try:
|
|
||||||
import llama_cpp
|
|
||||||
except:
|
|
||||||
llama_cpp = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
import llama_cpp_cuda
|
|
||||||
except:
|
|
||||||
llama_cpp_cuda = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
import llama_cpp_cuda_tensorcores
|
|
||||||
except:
|
|
||||||
llama_cpp_cuda_tensorcores = None
|
|
||||||
|
|
||||||
|
|
||||||
def llama_cpp_lib():
|
|
||||||
if shared.args.cpu and llama_cpp is not None:
|
|
||||||
return llama_cpp
|
|
||||||
elif shared.args.tensorcores and llama_cpp_cuda_tensorcores is not None:
|
|
||||||
return llama_cpp_cuda_tensorcores
|
|
||||||
elif llama_cpp_cuda is not None:
|
|
||||||
return llama_cpp_cuda
|
|
||||||
else:
|
|
||||||
return llama_cpp
|
|
||||||
|
|
||||||
|
|
||||||
class LlamacppHF(PreTrainedModel):
|
class LlamacppHF(PreTrainedModel):
|
||||||
def __init__(self, model, path):
|
def __init__(self, model, path):
|
||||||
|
@ -4,37 +4,12 @@ from functools import partial
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import llama_cpp_python_hijack, shared
|
from modules import shared
|
||||||
from modules.callbacks import Iteratorize
|
from modules.callbacks import Iteratorize
|
||||||
|
from modules.llama_cpp_python_hijack import llama_cpp_lib
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.text_generation import get_max_prompt_length
|
from modules.text_generation import get_max_prompt_length
|
||||||
|
|
||||||
try:
|
|
||||||
import llama_cpp
|
|
||||||
except:
|
|
||||||
llama_cpp = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
import llama_cpp_cuda
|
|
||||||
except:
|
|
||||||
llama_cpp_cuda = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
import llama_cpp_cuda_tensorcores
|
|
||||||
except:
|
|
||||||
llama_cpp_cuda_tensorcores = None
|
|
||||||
|
|
||||||
|
|
||||||
def llama_cpp_lib():
|
|
||||||
if shared.args.cpu and llama_cpp is not None:
|
|
||||||
return llama_cpp
|
|
||||||
elif shared.args.tensorcores and llama_cpp_cuda_tensorcores is not None:
|
|
||||||
return llama_cpp_cuda_tensorcores
|
|
||||||
elif llama_cpp_cuda is not None:
|
|
||||||
return llama_cpp_cuda
|
|
||||||
else:
|
|
||||||
return llama_cpp
|
|
||||||
|
|
||||||
|
|
||||||
def ban_eos_logits_processor(eos_token, input_ids, logits):
|
def ban_eos_logits_processor(eos_token, input_ids, logits):
|
||||||
logits[eos_token] = -float('inf')
|
logits[eos_token] = -float('inf')
|
||||||
|
Loading…
Reference in New Issue
Block a user