Merge pull request #5534 from oobabooga/dev

Merge dev branch
oobabooga, 2024-02-17 18:09:40 -03:00 (committed by GitHub)
commit 7838075990
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
2 changed files with 19 additions and 16 deletions

modules/exllamav2.py

@@ -51,20 +51,21 @@ class Exllamav2Model:
 
         model = ExLlamaV2(config)
 
-        if shared.args.cache_8bit:
-            cache = ExLlamaV2Cache_8bit(model, lazy=True)
-        else:
-            cache = ExLlamaV2Cache(model, lazy=True)
-
-        if shared.args.autosplit:
-            model.load_autosplit(cache)
-        else:
+        if not shared.args.autosplit:
             split = None
             if shared.args.gpu_split:
                 split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
             model.load(split)
 
+        if shared.args.cache_8bit:
+            cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
+        else:
+            cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit:
+            model.load_autosplit(cache)
+
         tokenizer = ExLlamaV2Tokenizer(config)
         generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
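Both hunks make the same change: load the model weights first, then allocate the cache, and keep the cache lazy only when autosplit is going to materialize it. Read as straight Python, the new load path is roughly the following (a minimal sketch, not code from the repo; load_model and args are hypothetical stand-ins for the loader function and shared.args):

    from exllamav2 import (
        ExLlamaV2,
        ExLlamaV2Cache,
        ExLlamaV2Cache_8bit,
        ExLlamaV2Config,
    )

    def load_model(config: ExLlamaV2Config, args):
        model = ExLlamaV2(config)

        # 1. Without autosplit, load the weights up front. gpu_split is a
        #    comma-separated list of per-GPU allocations, e.g. "17,24".
        if not args.autosplit:
            split = None
            if args.gpu_split:
                split = [float(alloc) for alloc in args.gpu_split.split(",")]
            model.load(split)

        # 2. Allocate the KV cache after the weights. It stays lazy only
        #    when autosplit will materialize it in the next step.
        if args.cache_8bit:
            cache = ExLlamaV2Cache_8bit(model, lazy=args.autosplit)
        else:
            cache = ExLlamaV2Cache(model, lazy=args.autosplit)

        # 3. With autosplit, weights and cache are loaded together.
        if args.autosplit:
            model.load_autosplit(cache)

        return model, cache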

modules/exllamav2_hf.py

@@ -36,24 +36,26 @@ class Exllamav2HF(PreTrainedModel):
     def __init__(self, config: ExLlamaV2Config):
         super().__init__(PretrainedConfig())
         self.ex_config = config
-        self.ex_model = ExLlamaV2(config)
         self.loras = None
         self.generation_config = GenerationConfig()
 
-        if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=True)
-        else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=True)
+        self.ex_model = ExLlamaV2(config)
 
-        if shared.args.autosplit:
-            self.ex_model.load_autosplit(self.ex_cache)
-        else:
+        if not shared.args.autosplit:
             split = None
             if shared.args.gpu_split:
                 split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
             self.ex_model.load(split)
 
+        if shared.args.cache_8bit:
+            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
+        else:
+            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit:
+            self.ex_model.load_autosplit(self.ex_cache)
+
         self.past_seq = None
 
         if shared.args.cfg_cache:
             if shared.args.cache_8bit:
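Read together, the two hunks show what the reordering fixes (a reading of the diff, not stated in the commit message): the old code always constructed the cache with lazy=True, but a lazy cache is only materialized by load_autosplit(), so in the non-autosplit path the cache was never actually allocated. Tying lazy to shared.args.autosplit matches the lazy flag to the one path that materializes it:

    # The lazy-cache contract the new code appears to rely on:
    cache = ExLlamaV2Cache(model, lazy=True)   # allocation deferred...
    model.load_autosplit(cache)                # ...until the weights load here

    cache = ExLlamaV2Cache(model, lazy=False)  # allocated immediately; requires
                                               # an already-loaded model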