mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Add --autosplit flag for ExLlamaV2 (#5524)
This commit is contained in:
parent
4039999be5
commit
a6730f88f7
@ -51,18 +51,21 @@ class Exllamav2Model:
|
|||||||
|
|
||||||
model = ExLlamaV2(config)
|
model = ExLlamaV2(config)
|
||||||
|
|
||||||
split = None
|
if shared.args.cache_8bit:
|
||||||
if shared.args.gpu_split:
|
cache = ExLlamaV2Cache_8bit(model, lazy=True)
|
||||||
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
else:
|
||||||
|
cache = ExLlamaV2Cache(model, lazy=True)
|
||||||
|
|
||||||
model.load(split)
|
if shared.args.autosplit:
|
||||||
|
model.load_autosplit(cache)
|
||||||
|
else:
|
||||||
|
split = None
|
||||||
|
if shared.args.gpu_split:
|
||||||
|
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
||||||
|
|
||||||
|
model.load(split)
|
||||||
|
|
||||||
tokenizer = ExLlamaV2Tokenizer(config)
|
tokenizer = ExLlamaV2Tokenizer(config)
|
||||||
if shared.args.cache_8bit:
|
|
||||||
cache = ExLlamaV2Cache_8bit(model)
|
|
||||||
else:
|
|
||||||
cache = ExLlamaV2Cache(model)
|
|
||||||
|
|
||||||
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
|
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
|
||||||
|
|
||||||
result = self()
|
result = self()
|
||||||
|
@ -37,18 +37,22 @@ class Exllamav2HF(PreTrainedModel):
|
|||||||
super().__init__(PretrainedConfig())
|
super().__init__(PretrainedConfig())
|
||||||
self.ex_config = config
|
self.ex_config = config
|
||||||
self.ex_model = ExLlamaV2(config)
|
self.ex_model = ExLlamaV2(config)
|
||||||
split = None
|
|
||||||
if shared.args.gpu_split:
|
|
||||||
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
|
||||||
|
|
||||||
self.ex_model.load(split)
|
|
||||||
self.generation_config = GenerationConfig()
|
|
||||||
self.loras = None
|
self.loras = None
|
||||||
|
self.generation_config = GenerationConfig()
|
||||||
|
|
||||||
if shared.args.cache_8bit:
|
if shared.args.cache_8bit:
|
||||||
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
|
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=True)
|
||||||
else:
|
else:
|
||||||
self.ex_cache = ExLlamaV2Cache(self.ex_model)
|
self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=True)
|
||||||
|
|
||||||
|
if shared.args.autosplit:
|
||||||
|
self.ex_model.load_autosplit(self.ex_cache)
|
||||||
|
else:
|
||||||
|
split = None
|
||||||
|
if shared.args.gpu_split:
|
||||||
|
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
||||||
|
|
||||||
|
self.ex_model.load(split)
|
||||||
|
|
||||||
self.past_seq = None
|
self.past_seq = None
|
||||||
if shared.args.cfg_cache:
|
if shared.args.cfg_cache:
|
||||||
|
@ -78,6 +78,7 @@ loaders_and_params = OrderedDict({
|
|||||||
'no_flash_attn',
|
'no_flash_attn',
|
||||||
'num_experts_per_token',
|
'num_experts_per_token',
|
||||||
'cache_8bit',
|
'cache_8bit',
|
||||||
|
'autosplit',
|
||||||
'alpha_value',
|
'alpha_value',
|
||||||
'compress_pos_emb',
|
'compress_pos_emb',
|
||||||
'trust_remote_code',
|
'trust_remote_code',
|
||||||
@ -89,6 +90,7 @@ loaders_and_params = OrderedDict({
|
|||||||
'no_flash_attn',
|
'no_flash_attn',
|
||||||
'num_experts_per_token',
|
'num_experts_per_token',
|
||||||
'cache_8bit',
|
'cache_8bit',
|
||||||
|
'autosplit',
|
||||||
'alpha_value',
|
'alpha_value',
|
||||||
'compress_pos_emb',
|
'compress_pos_emb',
|
||||||
'exllamav2_info',
|
'exllamav2_info',
|
||||||
|
@ -134,6 +134,7 @@ group.add_argument('--row_split', action='store_true', help='Split the model by
|
|||||||
# ExLlamaV2
|
# ExLlamaV2
|
||||||
group = parser.add_argument_group('ExLlamaV2')
|
group = parser.add_argument_group('ExLlamaV2')
|
||||||
group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
|
group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
|
||||||
|
group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.')
|
||||||
group.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
|
group.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
|
||||||
group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
|
group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
|
||||||
group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
|
group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
|
||||||
|
@ -76,6 +76,7 @@ def list_model_elements():
|
|||||||
'no_flash_attn',
|
'no_flash_attn',
|
||||||
'num_experts_per_token',
|
'num_experts_per_token',
|
||||||
'cache_8bit',
|
'cache_8bit',
|
||||||
|
'autosplit',
|
||||||
'threads',
|
'threads',
|
||||||
'threads_batch',
|
'threads_batch',
|
||||||
'n_batch',
|
'n_batch',
|
||||||
|
@ -132,6 +132,7 @@ def create_ui():
|
|||||||
shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
|
shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
|
||||||
shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
|
shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
|
||||||
shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
|
shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
|
||||||
|
shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
|
||||||
shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
|
shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
|
||||||
shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
|
shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
|
||||||
shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
|
shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
|
||||||
|
Loading…
Reference in New Issue
Block a user