diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index 9d2a0fc4..2758fa2d 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -13,6 +13,17 @@ from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
 from modules import shared
 from modules.text_generation import get_max_prompt_length
 
+try:
+    import flash_attn
+except ModuleNotFoundError:
+    logger.warning(
+        'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
+        'to be a lot higher than it could be.\n'
+        'Try installing flash-attention following the instructions here: '
+        'https://github.com/Dao-AILab/flash-attention#installation-and-features'
+    )
+    pass
+
 
 class Exllamav2Model:
     def __init__(self):
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index a8837af1..457942ac 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -11,6 +11,17 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from modules import shared
 from modules.logging_colors import logger
 
+try:
+    import flash_attn
+except ModuleNotFoundError:
+    logger.warning(
+        'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
+        'to be a lot higher than it could be.\n'
+        'Try installing flash-attention following the instructions here: '
+        'https://github.com/Dao-AILab/flash-attention#installation-and-features'
+    )
+    pass
+
 
 class Exllamav2HF(PreTrainedModel):
     def __init__(self, config: ExLlamaV2Config):
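
For context, the check added in both hunks can be exercised on its own. The sketch below is only an illustration of the same optional-dependency pattern, not part of the patch: it substitutes Python's standard logging module for the project's modules.logging_colors logger, and the HAS_FLASH_ATTN flag is an invented name used here just to show how a caller could branch on availability.

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# Optional-dependency check: flash-attention only improves speed and VRAM usage,
# so a missing package is reported as a warning and execution continues.
try:
    import flash_attn  # noqa: F401  (imported solely to test availability)
    HAS_FLASH_ATTN = True  # hypothetical flag, not present in the patch
except ModuleNotFoundError:
    HAS_FLASH_ATTN = False
    logger.warning(
        'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
        'to be a lot higher than it could be.\n'
        'Try installing flash-attention following the instructions here: '
        'https://github.com/Dao-AILab/flash-attention#installation-and-features'
    )
```

Because the try/except sits at module level in both files, the warning is emitted once when the module is first imported rather than on every model load.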