Add --num_experts_per_token parameter (ExLlamav2) (#4955)

oobabooga 2023-12-17 12:08:33 -03:00 committed by GitHub
parent 12690d3ffc
commit f1f2c4c3f4
7 changed files with 28 additions and 20 deletions


@@ -274,6 +274,7 @@ List of command-line flags
 |`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. |
 |`--no_flash_attn` | Force flash-attention to not be used. |
 |`--cache_8bit` | Use 8-bit cache to save VRAM. |
+|`--num_experts_per_token NUM_EXPERTS_PER_TOKEN` | Number of experts to use for generation. Applies to MoE models like Mixtral. |
 #### AutoGPTQ
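The flag controls top-k expert routing in mixture-of-experts models. As a rough illustration (plain PyTorch, not ExLlamaV2's actual kernel; every name below is made up for the sketch), a router scores all experts for each token, keeps the top num_experts_per_token of them (Mixtral ships with 2), and mixes their outputs:

import torch

def moe_forward(hidden, router_weight, expert_weights, num_experts_per_token=2):
    # hidden: [num_tokens, dim]; router_weight: [num_experts, dim];
    # expert_weights: list of [dim, dim] matrices standing in for full expert MLPs.
    scores = hidden @ router_weight.t()                            # per-token expert scores
    top_scores, chosen = torch.topk(scores, num_experts_per_token, dim=-1)
    top_scores = torch.softmax(top_scores, dim=-1)                 # renormalize over the chosen experts
    out = torch.zeros_like(hidden)
    for tok in range(hidden.shape[0]):
        for slot in range(num_experts_per_token):
            e = int(chosen[tok, slot])
            out[tok] += top_scores[tok, slot] * (hidden[tok] @ expert_weights[e])
    return out

torch.manual_seed(0)
dim, n_experts = 8, 4
hidden = torch.randn(3, dim)
router = torch.randn(n_experts, dim)
experts = [torch.randn(dim, dim) for _ in range(n_experts)]
print(moe_forward(hidden, router, experts, num_experts_per_token=2).shape)   # torch.Size([3, 8])

Raising the setting simply runs more expert MLPs per token, at the cost of extra compute.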


@@ -48,6 +48,7 @@ class Exllamav2Model:
 config.scale_pos_emb = shared.args.compress_pos_emb
 config.scale_alpha_value = shared.args.alpha_value
 config.no_flash_attn = shared.args.no_flash_attn
+config.num_experts_per_token = int(shared.args.num_experts_per_token)
 model = ExLlamaV2(config)
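Outside the webui, the same override can be reproduced directly against the exllamav2 library. A minimal sketch, assuming an exllamav2 release with Mixtral support (one whose ExLlamaV2Config exposes num_experts_per_token); the model path is a placeholder:

from exllamav2 import ExLlamaV2, ExLlamaV2Config

config = ExLlamaV2Config()
config.model_dir = "/path/to/Mixtral-8x7B-exl2"    # placeholder model directory
config.prepare()                                   # read the model's config.json and fill defaults
config.num_experts_per_token = 2                   # the field this commit sets from --num_experts_per_token
model = ExLlamaV2(config)
model.load()                                       # gpu_split handling omitted for brevity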


@@ -165,5 +165,6 @@ class Exllamav2HF(PreTrainedModel):
 config.scale_pos_emb = shared.args.compress_pos_emb
 config.scale_alpha_value = shared.args.alpha_value
 config.no_flash_attn = shared.args.no_flash_attn
+config.num_experts_per_token = int(shared.args.num_experts_per_token)
 return Exllamav2HF(config)
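The int() cast in both loaders is presumably a guard for values that arrive as floats, for example from a UI number box or a saved settings file, since the config expects an integer expert count:

raw_value = 2.0                    # e.g. what a number widget or YAML setting might hand back
config_value = int(raw_value)      # what the loaders actually assign to the config
assert config_value == 2 and isinstance(config_value, int)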


@@ -65,6 +65,18 @@ loaders_and_params = OrderedDict({
 'logits_all',
 'llamacpp_HF_info',
 ],
+'ExLlamav2_HF': [
+'gpu_split',
+'max_seq_len',
+'cfg_cache',
+'no_flash_attn',
+'num_experts_per_token',
+'cache_8bit',
+'alpha_value',
+'compress_pos_emb',
+'trust_remote_code',
+'no_use_fast',
+],
 'ExLlama_HF': [
 'gpu_split',
 'max_seq_len',
@@ -75,17 +87,6 @@ loaders_and_params = OrderedDict({
 'trust_remote_code',
 'no_use_fast',
 ],
-'ExLlamav2_HF': [
-'gpu_split',
-'max_seq_len',
-'cfg_cache',
-'no_flash_attn',
-'cache_8bit',
-'alpha_value',
-'compress_pos_emb',
-'trust_remote_code',
-'no_use_fast',
-],
 'AutoGPTQ': [
 'triton',
 'no_inject_fused_attention',
@@ -123,6 +124,16 @@ loaders_and_params = OrderedDict({
 'no_use_fast',
 'gptq_for_llama_info',
 ],
+'ExLlamav2': [
+'gpu_split',
+'max_seq_len',
+'no_flash_attn',
+'num_experts_per_token',
+'cache_8bit',
+'alpha_value',
+'compress_pos_emb',
+'exllamav2_info',
+],
 'ExLlama': [
 'gpu_split',
 'max_seq_len',
@@ -131,15 +142,6 @@ loaders_and_params = OrderedDict({
 'compress_pos_emb',
 'exllama_info',
 ],
-'ExLlamav2': [
-'gpu_split',
-'max_seq_len',
-'no_flash_attn',
-'cache_8bit',
-'alpha_value',
-'compress_pos_emb',
-'exllamav2_info',
-],
 'ctransformers': [
 'n_ctx',
 'n_gpu_layers',
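The ExLlamav2 blocks are not only re-ordered; each gains 'num_experts_per_token', which is what makes the new field show up for those loaders. A generic sketch (not the project's exact code, names shortened) of how such a loader-to-parameters mapping can drive widget visibility:

from collections import OrderedDict

loaders_and_params = OrderedDict({
    'ExLlamav2_HF': ['gpu_split', 'max_seq_len', 'num_experts_per_token', 'cache_8bit'],
    'ExLlamav2': ['gpu_split', 'max_seq_len', 'num_experts_per_token', 'cache_8bit'],
})

def visibility_updates(selected_loader, all_params):
    # Show a widget only if the currently selected loader lists it.
    active = set(loaders_and_params.get(selected_loader, []))
    return {name: name in active for name in all_params}

print(visibility_updates('ExLlamav2', ['gpu_split', 'num_experts_per_token', 'n_ctx']))
# {'gpu_split': True, 'num_experts_per_token': True, 'n_ctx': False}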


@@ -125,6 +125,7 @@ parser.add_argument('--max_seq_len', type=int, default=2048, help='Maximum seque
 parser.add_argument('--cfg-cache', action='store_true', help='ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.')
 parser.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
 parser.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
+parser.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
 # AutoGPTQ
 parser.add_argument('--triton', action='store_true', help='Use triton.')
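The flag is a plain argparse integer with a default of 2, Mixtral's usual setting. A minimal reproduction of its parsing behavior, isolated from the webui's full parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num_experts_per_token', type=int, default=2,
                    help='Number of experts to use for generation. Applies to MoE models like Mixtral.')

print(parser.parse_args([]).num_experts_per_token)                                   # 2 (default)
print(parser.parse_args(['--num_experts_per_token', '3']).num_experts_per_token)     # 3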


@@ -73,6 +73,7 @@ def list_model_elements():
 'disable_exllamav2',
 'cfg_cache',
 'no_flash_attn',
+'num_experts_per_token',
 'cache_8bit',
 'threads',
 'threads_batch',
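list_model_elements() is a registry of element names, and adding 'num_experts_per_token' lets the new field be picked up wherever the UI state is collected. A generic sketch (illustrative names, not the project's exact code) of gathering the current value of each listed element into one dict:

def gather_values(element_names, registry):
    # Pull the current value of every named element that exists in the registry.
    return {name: registry[name] for name in element_names if name in registry}

registry = {'no_flash_attn': False, 'num_experts_per_token': 2, 'cache_8bit': True}
print(gather_values(['no_flash_attn', 'num_experts_per_token', 'cache_8bit'], registry))
# {'no_flash_attn': False, 'num_experts_per_token': 2, 'cache_8bit': True}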


@@ -129,6 +129,7 @@ def create_ui():
 shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
 shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
 shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
+shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
 shared.gradio['gptq_for_llama_info'] = gr.Markdown('Legacy loader for compatibility with older GPUs. ExLlama_HF or AutoGPTQ are preferred for GPTQ models when supported.')
 shared.gradio['exllama_info'] = gr.Markdown("ExLlama_HF is recommended over ExLlama for better integration with extensions and more consistent sampling behavior across loaders.")
 shared.gradio['exllamav2_info'] = gr.Markdown("ExLlamav2_HF is recommended over ExLlamav2 for better integration with extensions and more consistent sampling behavior across loaders.")
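The new control is a gr.Number seeded from the command-line default. A minimal standalone gradio sketch of the same kind of field, assuming a gradio version where Number accepts label/value/info, as the diff line above does:

import gradio as gr

with gr.Blocks() as demo:
    # Same kind of control the commit registers under shared.gradio['num_experts_per_token'];
    # value 2 mirrors the argparse default, and precision is left unset, matching the diff.
    num_experts = gr.Number(label="Number of experts per token", value=2,
                            info="Only applies to MoE models like Mixtral.")

# demo.launch()  # uncomment to serve the toy form locally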