From 605ec3c9f22a5af4db7b37c1af5a0a1aefce1a65 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 18 Sep 2023 12:25:17 -0700 Subject: [PATCH] Add a warning about ExLlamaV2 without flash-attn --- modules/exllamav2.py | 11 +++++++++++ modules/exllamav2_hf.py | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/modules/exllamav2.py b/modules/exllamav2.py index 9d2a0fc4..2758fa2d 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -13,6 +13,17 @@ from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler from modules import shared from modules.text_generation import get_max_prompt_length +try: + import flash_attn +except ModuleNotFoundError: + logger.warning( + 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage ' + 'to be a lot higher than it could be.\n' + 'Try installing flash-attention following the instructions here: ' + 'https://github.com/Dao-AILab/flash-attention#installation-and-features' + ) + pass + class Exllamav2Model: def __init__(self): diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py index a8837af1..457942ac 100644 --- a/modules/exllamav2_hf.py +++ b/modules/exllamav2_hf.py @@ -11,6 +11,17 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from modules import shared from modules.logging_colors import logger +try: + import flash_attn +except ModuleNotFoundError: + logger.warning( + 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage ' + 'to be a lot higher than it could be.\n' + 'Try installing flash-attention following the instructions here: ' + 'https://github.com/Dao-AILab/flash-attention#installation-and-features' + ) + pass + class Exllamav2HF(PreTrainedModel): def __init__(self, config: ExLlamaV2Config):