From 77abd9b69bb089a9238b05a6889acc18fbfc053f Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 2 Nov 2023 08:19:42 -0700
Subject: [PATCH] Add no_flash_attn option

---
 README.md               | 1 +
 modules/exllamav2.py    | 1 +
 modules/exllamav2_hf.py | 1 +
 modules/shared.py       | 1 +
 4 files changed, 4 insertions(+)

diff --git a/README.md b/README.md
index 652f654d..ad3736a4 100644
--- a/README.md
+++ b/README.md
@@ -336,6 +336,7 @@ Optionally, you can use the following command-line flags:
 |`--gpu-split` | Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7. |
 |`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. |
 |`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. |
+|`--no_flash_attn` | Force flash-attention to not be used. |
 
 #### AutoGPTQ
 
diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index 558c2365..3f3b3587 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -46,6 +46,7 @@ class Exllamav2Model:
         config.max_seq_len = shared.args.max_seq_len
         config.scale_pos_emb = shared.args.compress_pos_emb
         config.scale_alpha_value = shared.args.alpha_value
+        config.no_flash_attn = shared.args.no_flash_attn
 
         model = ExLlamaV2(config)
 
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index 70d64200..5d4aa515 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -152,5 +152,6 @@ class Exllamav2HF(PreTrainedModel):
         config.max_seq_len = shared.args.max_seq_len
         config.scale_pos_emb = shared.args.compress_pos_emb
         config.scale_alpha_value = shared.args.alpha_value
+        config.no_flash_attn = shared.args.no_flash_attn
 
         return Exllamav2HF(config)
 
diff --git a/modules/shared.py b/modules/shared.py
index 626c2bf8..e1da167f 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -117,6 +117,7 @@ parser.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (
 parser.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
 parser.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
 parser.add_argument('--cfg-cache', action='store_true', help='ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.')
+parser.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
 
 # AutoGPTQ
 parser.add_argument('--triton', action='store_true', help='Use triton.')
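
Note (not part of the patch): a minimal sketch of how the new flag travels from the command line into the loader config. FakeConfig here is a hypothetical stand-in for exllamav2's ExLlamaV2Config object; the real loaders assign shared.args.no_flash_attn onto that config before building the model.

    import argparse

    # Hypothetical stand-in for the exllamav2 config object; the real code
    # sets config.no_flash_attn on an ExLlamaV2Config instance.
    class FakeConfig:
        no_flash_attn = False

    parser = argparse.ArgumentParser()
    # Same definition the patch adds to modules/shared.py.
    parser.add_argument('--no_flash_attn', action='store_true',
                        help='Force flash-attention to not be used.')
    args = parser.parse_args(['--no_flash_attn'])

    config = FakeConfig()
    # Mirrors the assignment added in modules/exllamav2.py and modules/exllamav2_hf.py.
    config.no_flash_attn = args.no_flash_attn
    print(config.no_flash_attn)  # True

In short, passing --no_flash_attn at launch sets a boolean on the parsed args, and both ExLlamaV2 loaders copy that boolean onto their config so flash-attention is skipped.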