mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 17:06:53 +01:00
ExllamaV2 tensor parallelism to increase multi gpu inference speeds (#6356)
This commit is contained in:
parent
301375834e
commit
46996f6519
@ -7,6 +7,7 @@ from exllamav2 import (
|
|||||||
ExLlamaV2Cache,
|
ExLlamaV2Cache,
|
||||||
ExLlamaV2Cache_8bit,
|
ExLlamaV2Cache_8bit,
|
||||||
ExLlamaV2Cache_Q4,
|
ExLlamaV2Cache_Q4,
|
||||||
|
ExLlamaV2Cache_TP,
|
||||||
ExLlamaV2Config,
|
ExLlamaV2Config,
|
||||||
ExLlamaV2Tokenizer
|
ExLlamaV2Tokenizer
|
||||||
)
|
)
|
||||||
@ -54,21 +55,30 @@ class Exllamav2Model:
|
|||||||
|
|
||||||
model = ExLlamaV2(config)
|
model = ExLlamaV2(config)
|
||||||
|
|
||||||
if not shared.args.autosplit:
|
|
||||||
split = None
|
split = None
|
||||||
if shared.args.gpu_split:
|
if shared.args.gpu_split:
|
||||||
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
||||||
|
|
||||||
|
if shared.args.enable_tp:
|
||||||
|
model.load_tp(split)
|
||||||
|
elif not shared.args.autosplit:
|
||||||
model.load(split)
|
model.load(split)
|
||||||
|
|
||||||
|
# Determine the correct cache type
|
||||||
if shared.args.cache_8bit:
|
if shared.args.cache_8bit:
|
||||||
cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
|
cache_type = ExLlamaV2Cache_8bit
|
||||||
elif shared.args.cache_4bit:
|
elif shared.args.cache_4bit:
|
||||||
cache = ExLlamaV2Cache_Q4(model, lazy=shared.args.autosplit)
|
cache_type = ExLlamaV2Cache_Q4
|
||||||
else:
|
else:
|
||||||
cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
|
cache_type = ExLlamaV2Cache
|
||||||
|
|
||||||
if shared.args.autosplit:
|
# Use TP if specified
|
||||||
|
if shared.args.enable_tp:
|
||||||
|
cache = ExLlamaV2Cache_TP(model, base=cache_type)
|
||||||
|
else:
|
||||||
|
cache = cache_type(model, lazy=shared.args.autosplit)
|
||||||
|
|
||||||
|
if shared.args.autosplit and not shared.args.enable_tp:
|
||||||
model.load_autosplit(cache)
|
model.load_autosplit(cache)
|
||||||
|
|
||||||
tokenizer = ExLlamaV2Tokenizer(config)
|
tokenizer = ExLlamaV2Tokenizer(config)
|
||||||
|
@ -9,6 +9,7 @@ from exllamav2 import (
|
|||||||
ExLlamaV2Cache,
|
ExLlamaV2Cache,
|
||||||
ExLlamaV2Cache_8bit,
|
ExLlamaV2Cache_8bit,
|
||||||
ExLlamaV2Cache_Q4,
|
ExLlamaV2Cache_Q4,
|
||||||
|
ExLlamaV2Cache_TP,
|
||||||
ExLlamaV2Config
|
ExLlamaV2Config
|
||||||
)
|
)
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
@ -42,21 +43,30 @@ class Exllamav2HF(PreTrainedModel):
|
|||||||
|
|
||||||
self.ex_model = ExLlamaV2(config)
|
self.ex_model = ExLlamaV2(config)
|
||||||
|
|
||||||
if not shared.args.autosplit:
|
|
||||||
split = None
|
split = None
|
||||||
if shared.args.gpu_split:
|
if shared.args.gpu_split:
|
||||||
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
||||||
|
|
||||||
self.ex_model.load(split)
|
if shared.args.enable_tp:
|
||||||
|
model.load_tp(split)
|
||||||
|
elif not shared.args.autosplit:
|
||||||
|
model.load(split)
|
||||||
|
|
||||||
|
# Determine the correct cache type
|
||||||
if shared.args.cache_8bit:
|
if shared.args.cache_8bit:
|
||||||
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
|
cache_type = ExLlamaV2Cache_8bit
|
||||||
elif shared.args.cache_4bit:
|
elif shared.args.cache_4bit:
|
||||||
self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)
|
cache_type = ExLlamaV2Cache_Q4
|
||||||
else:
|
else:
|
||||||
self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
|
cache_type = ExLlamaV2Cache
|
||||||
|
|
||||||
if shared.args.autosplit:
|
# Use TP if specified
|
||||||
|
if shared.args.enable_tp:
|
||||||
|
self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
|
||||||
|
else:
|
||||||
|
self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)
|
||||||
|
|
||||||
|
if shared.args.autosplit and not shared.args.enable_tp:
|
||||||
self.ex_model.load_autosplit(self.ex_cache)
|
self.ex_model.load_autosplit(self.ex_cache)
|
||||||
|
|
||||||
self.past_seq = None
|
self.past_seq = None
|
||||||
|
@ -146,6 +146,7 @@ group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to n
|
|||||||
group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
|
group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
|
||||||
group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.')
|
group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.')
|
||||||
group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
|
group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
|
||||||
|
group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.')
|
||||||
|
|
||||||
# AutoGPTQ
|
# AutoGPTQ
|
||||||
group = parser.add_argument_group('AutoGPTQ')
|
group = parser.add_argument_group('AutoGPTQ')
|
||||||
|
Loading…
Reference in New Issue
Block a user