ExllamaV2 tensor parallelism to increase multi gpu inference speeds (#6356)

This commit is contained in:
RandoInternetPreson 2024-09-27 23:26:03 -04:00 committed by GitHub
parent 301375834e
commit 46996f6519
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 17 deletions

View File

@ -7,6 +7,7 @@ from exllamav2 import (
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Cache_Q4,
ExLlamaV2Cache_TP,
ExLlamaV2Config,
ExLlamaV2Tokenizer
)
@ -54,21 +55,30 @@ class Exllamav2Model:
model = ExLlamaV2(config)
if not shared.args.autosplit:
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
if shared.args.enable_tp:
model.load_tp(split)
elif not shared.args.autosplit:
model.load(split)
# Determine the correct cache type
if shared.args.cache_8bit:
cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
cache_type = ExLlamaV2Cache_8bit
elif shared.args.cache_4bit:
cache = ExLlamaV2Cache_Q4(model, lazy=shared.args.autosplit)
cache_type = ExLlamaV2Cache_Q4
else:
cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
cache_type = ExLlamaV2Cache
if shared.args.autosplit:
# Use TP if specified
if shared.args.enable_tp:
cache = ExLlamaV2Cache_TP(model, base=cache_type)
else:
cache = cache_type(model, lazy=shared.args.autosplit)
if shared.args.autosplit and not shared.args.enable_tp:
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)

View File

@ -9,6 +9,7 @@ from exllamav2 import (
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Cache_Q4,
ExLlamaV2Cache_TP,
ExLlamaV2Config
)
from torch.nn import CrossEntropyLoss
@ -42,21 +43,30 @@ class Exllamav2HF(PreTrainedModel):
self.ex_model = ExLlamaV2(config)
if not shared.args.autosplit:
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
self.ex_model.load(split)
if shared.args.enable_tp:
model.load_tp(split)
elif not shared.args.autosplit:
model.load(split)
# Determine the correct cache type
if shared.args.cache_8bit:
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
cache_type = ExLlamaV2Cache_8bit
elif shared.args.cache_4bit:
self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)
cache_type = ExLlamaV2Cache_Q4
else:
self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
cache_type = ExLlamaV2Cache
if shared.args.autosplit:
# Use TP if specified
if shared.args.enable_tp:
self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
else:
self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)
if shared.args.autosplit and not shared.args.enable_tp:
self.ex_model.load_autosplit(self.ex_cache)
self.past_seq = None

View File

@ -146,6 +146,7 @@ group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to n
group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.')
group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.')
# AutoGPTQ
group = parser.add_argument_group('AutoGPTQ')