transformers: add use_flash_attention_2 option (#4373)

This commit is contained in:
feng lui 2023-11-05 00:59:33 +08:00 committed by GitHub
parent add359379e
commit 4766a57352
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 9 additions and 1 deletions

View File

@ -300,6 +300,7 @@ Optionally, you can use the following command-line flags:
| `--sdp-attention` | Use PyTorch 2.0's SDP attention. Same as above. | | `--sdp-attention` | Use PyTorch 2.0's SDP attention. Same as above. |
| `--trust-remote-code` | Set `trust_remote_code=True` while loading the model. Necessary for some models. | | `--trust-remote-code` | Set `trust_remote_code=True` while loading the model. Necessary for some models. |
| `--use_fast` | Set `use_fast=True` while loading the tokenizer. | | `--use_fast` | Set `use_fast=True` while loading the tokenizer. |
| `--use_flash_attention_2` | Set use_flash_attention_2=True while loading the model. |
#### Accelerate 4-bit #### Accelerate 4-bit

View File

@ -9,7 +9,6 @@ loaders_and_params = OrderedDict({
'Transformers': [ 'Transformers': [
'cpu_memory', 'cpu_memory',
'gpu_memory', 'gpu_memory',
'trust_remote_code',
'load_in_8bit', 'load_in_8bit',
'bf16', 'bf16',
'cpu', 'cpu',
@ -21,6 +20,7 @@ loaders_and_params = OrderedDict({
'compute_dtype', 'compute_dtype',
'trust_remote_code', 'trust_remote_code',
'use_fast', 'use_fast',
'use_flash_attention_2',
'alpha_value', 'alpha_value',
'rope_freq_base', 'rope_freq_base',
'compress_pos_emb', 'compress_pos_emb',

View File

@ -126,6 +126,10 @@ def huggingface_loader(model_name):
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16, 'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
'use_safetensors': True if shared.args.force_safetensors else None 'use_safetensors': True if shared.args.force_safetensors else None
} }
if shared.args.use_flash_attention_2:
params['use_flash_attention_2'] = True
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code']) config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code'])
if 'chatglm' in model_name.lower(): if 'chatglm' in model_name.lower():

View File

@ -93,6 +93,7 @@ parser.add_argument('--sdp-attention', action='store_true', help='Use PyTorch 2.
parser.add_argument('--trust-remote-code', action='store_true', help='Set trust_remote_code=True while loading the model. Necessary for some models.') parser.add_argument('--trust-remote-code', action='store_true', help='Set trust_remote_code=True while loading the model. Necessary for some models.')
parser.add_argument('--force-safetensors', action='store_true', help='Set use_safetensors=True while loading the model. This prevents arbitrary code execution.') parser.add_argument('--force-safetensors', action='store_true', help='Set use_safetensors=True while loading the model. This prevents arbitrary code execution.')
parser.add_argument('--use_fast', action='store_true', help='Set use_fast=True while loading the tokenizer.') parser.add_argument('--use_fast', action='store_true', help='Set use_fast=True while loading the tokenizer.')
parser.add_argument('--use_flash_attention_2', action='store_true', help='Set use_flash_attention_2=True while loading the model.')
# Accelerate 4-bit # Accelerate 4-bit
parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision (using bitsandbytes).') parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision (using bitsandbytes).')

View File

@ -53,6 +53,7 @@ def list_model_elements():
'load_in_8bit', 'load_in_8bit',
'trust_remote_code', 'trust_remote_code',
'use_fast', 'use_fast',
'use_flash_attention_2',
'load_in_4bit', 'load_in_4bit',
'compute_dtype', 'compute_dtype',
'quant_type', 'quant_type',

View File

@ -124,6 +124,7 @@ def create_ui():
shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed) shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed)
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='To enable this option, start the web UI with the --trust-remote-code flag. It is necessary for some models.', interactive=shared.args.trust_remote_code) shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='To enable this option, start the web UI with the --trust-remote-code flag. It is necessary for some models.', interactive=shared.args.trust_remote_code)
shared.gradio['use_fast'] = gr.Checkbox(label="use_fast", value=shared.args.use_fast, info='Set use_fast=True while loading the tokenizer. May trigger a conversion that takes several minutes.') shared.gradio['use_fast'] = gr.Checkbox(label="use_fast", value=shared.args.use_fast, info='Set use_fast=True while loading the tokenizer. May trigger a conversion that takes several minutes.')
shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.') shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.')
shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.') shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.') shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')