diff --git a/modules/exllamav2.py b/modules/exllamav2.py index 551ed498..239c2031 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -51,18 +51,21 @@ class Exllamav2Model: model = ExLlamaV2(config) - split = None - if shared.args.gpu_split: - split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] + if shared.args.cache_8bit: + cache = ExLlamaV2Cache_8bit(model, lazy=True) + else: + cache = ExLlamaV2Cache(model, lazy=True) - model.load(split) + if shared.args.autosplit: + model.load_autosplit(cache) + else: + split = None + if shared.args.gpu_split: + split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] + + model.load(split) tokenizer = ExLlamaV2Tokenizer(config) - if shared.args.cache_8bit: - cache = ExLlamaV2Cache_8bit(model) - else: - cache = ExLlamaV2Cache(model) - generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer) result = self() diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py index 944c39dd..e5b35a44 100644 --- a/modules/exllamav2_hf.py +++ b/modules/exllamav2_hf.py @@ -37,18 +37,22 @@ class Exllamav2HF(PreTrainedModel): super().__init__(PretrainedConfig()) self.ex_config = config self.ex_model = ExLlamaV2(config) - split = None - if shared.args.gpu_split: - split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] - - self.ex_model.load(split) - self.generation_config = GenerationConfig() self.loras = None + self.generation_config = GenerationConfig() if shared.args.cache_8bit: - self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model) + self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=True) else: - self.ex_cache = ExLlamaV2Cache(self.ex_model) + self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=True) + + if shared.args.autosplit: + self.ex_model.load_autosplit(self.ex_cache) + else: + split = None + if shared.args.gpu_split: + split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] + + self.ex_model.load(split) self.past_seq = None if shared.args.cfg_cache: diff --git a/modules/loaders.py b/modules/loaders.py index 26b7c5e2..08a7f229 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -78,6 +78,7 @@ loaders_and_params = OrderedDict({ 'no_flash_attn', 'num_experts_per_token', 'cache_8bit', + 'autosplit', 'alpha_value', 'compress_pos_emb', 'trust_remote_code', @@ -89,6 +90,7 @@ loaders_and_params = OrderedDict({ 'no_flash_attn', 'num_experts_per_token', 'cache_8bit', + 'autosplit', 'alpha_value', 'compress_pos_emb', 'exllamav2_info', diff --git a/modules/shared.py b/modules/shared.py index d8aef367..7bef04bf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -134,6 +134,7 @@ group.add_argument('--row_split', action='store_true', help='Split the model by # ExLlamaV2 group = parser.add_argument_group('ExLlamaV2') group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.') +group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.') group.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.') group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.') group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.') diff --git a/modules/ui.py b/modules/ui.py index 06498f69..bb5a3339 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -76,6 +76,7 @@ def list_model_elements(): 'no_flash_attn', 'num_experts_per_token', 'cache_8bit', + 'autosplit', 'threads', 'threads_batch', 'n_batch', diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 94b01937..14bc7caf 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -132,6 +132,7 @@ def create_ui(): shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk) shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16) shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.') + shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.') shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.') shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.') shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')