From d6bd71db7f3200c2b1ef46123c07374848aed86a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 17 Feb 2024 12:42:22 -0800
Subject: [PATCH] ExLlamaV2: fix loading when autosplit is not set

---
 modules/exllamav2.py    | 17 +++++++++--------
 modules/exllamav2_hf.py | 18 ++++++++++--------
 2 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index 239c2031..34072d0f 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -51,20 +51,21 @@ class Exllamav2Model:
 
         model = ExLlamaV2(config)
 
-        if shared.args.cache_8bit:
-            cache = ExLlamaV2Cache_8bit(model, lazy=True)
-        else:
-            cache = ExLlamaV2Cache(model, lazy=True)
-
-        if shared.args.autosplit:
-            model.load_autosplit(cache)
-        else:
+        if not shared.args.autosplit:
             split = None
             if shared.args.gpu_split:
                 split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
             model.load(split)
 
+        if shared.args.cache_8bit:
+            cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
+        else:
+            cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit:
+            model.load_autosplit(cache)
+
         tokenizer = ExLlamaV2Tokenizer(config)
         generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
 
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index e5b35a44..1e21c2f1 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -36,24 +36,26 @@ class Exllamav2HF(PreTrainedModel):
     def __init__(self, config: ExLlamaV2Config):
         super().__init__(PretrainedConfig())
         self.ex_config = config
-        self.ex_model = ExLlamaV2(config)
         self.loras = None
         self.generation_config = GenerationConfig()
 
-        if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=True)
-        else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=True)
+        self.ex_model = ExLlamaV2(config)
 
-        if shared.args.autosplit:
-            self.ex_model.load_autosplit(self.ex_cache)
-        else:
+        if not shared.args.autosplit:
             split = None
             if shared.args.gpu_split:
                 split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
             self.ex_model.load(split)
 
+        if shared.args.cache_8bit:
+            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
+        else:
+            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit:
+            self.ex_model.load_autosplit(self.ex_cache)
+
         self.past_seq = None
         if shared.args.cfg_cache:
             if shared.args.cache_8bit: