Add a warning about ExLlamaV2 without flash-attn

This commit is contained in:
oobabooga 2023-09-18 12:25:17 -07:00
parent f0ef971edb
commit 605ec3c9f2
2 changed files with 22 additions and 0 deletions

View File

@ -13,6 +13,17 @@ from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
from modules import shared from modules import shared
from modules.text_generation import get_max_prompt_length from modules.text_generation import get_max_prompt_length
try:
import flash_attn
except ModuleNotFoundError:
logger.warning(
'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
'to be a lot higher than it could be.\n'
'Try installing flash-attention following the instructions here: '
'https://github.com/Dao-AILab/flash-attention#installation-and-features'
)
pass
class Exllamav2Model: class Exllamav2Model:
def __init__(self): def __init__(self):

View File

@ -11,6 +11,17 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from modules import shared from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
try:
import flash_attn
except ModuleNotFoundError:
logger.warning(
'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
'to be a lot higher than it could be.\n'
'Try installing flash-attention following the instructions here: '
'https://github.com/Dao-AILab/flash-attention#installation-and-features'
)
pass
class Exllamav2HF(PreTrainedModel): class Exllamav2HF(PreTrainedModel):
def __init__(self, config: ExLlamaV2Config): def __init__(self, config: ExLlamaV2Config):