From 8cce1f1126049d5584382fef9916ab6f915f5e7e Mon Sep 17 00:00:00 2001
From: Forkoz <59298527+Ph0rk0z@users.noreply.github.com>
Date: Sat, 14 Oct 2023 19:12:41 +0000
Subject: [PATCH] Exllamav2 lora support (#4229)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
---
 modules/LoRA.py         | 32 +++++++++++++++++++++++++++++++-
 modules/exllama.py      |  4 +++-
 modules/exllamav2.py    | 13 ++++++++-----
 modules/exllamav2_hf.py | 10 +++++-----
 4 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/modules/LoRA.py b/modules/LoRA.py
index 10020552..b3997d80 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -13,6 +13,8 @@ def add_lora_to_model(lora_names):
         add_lora_autogptq(lora_names)
     elif shared.model.__class__.__name__ in ['ExllamaModel', 'ExllamaHF'] or shared.args.loader == 'ExLlama':
         add_lora_exllama(lora_names)
+    elif shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav2HF'] or shared.args.loader in ['ExLlamav2', 'ExLlamav2_HF']:
+        add_lora_exllamav2(lora_names)
     else:
         add_lora_transformers(lora_names)
 
@@ -64,8 +66,36 @@ def add_lora_exllama(lora_names):
         return
 
 
-# Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
+def add_lora_exllamav2(lora_names):
+
+    from exllamav2 import ExLlamaV2Lora
+
+    if isinstance(shared.model.loras, list):
+        for lora in shared.model.loras:
+            lora.unload()
+
+    if len(lora_names) > 0:
+        logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
+        shared.model.loras = []
+        for lora_name in lora_names:
+            lora_path = get_lora_path(lora_name)
+            if shared.model.__class__.__name__ == 'Exllamav2Model':
+                lora = ExLlamaV2Lora.from_directory(shared.model.model, str(lora_path))
+            else:
+                lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))
+
+            shared.model.loras.append(lora)
+
+        shared.lora_names = lora_names
+    else:
+        shared.lora_names = []
+        shared.model.loras = None
+
+
 def add_lora_autogptq(lora_names):
+    '''
+    Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
+    '''
     try:
         from auto_gptq import get_gptq_peft_model
 
diff --git a/modules/exllama.py b/modules/exllama.py
index cb92344e..4257ee07 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -98,7 +98,9 @@ class ExllamaModel:
 
     def get_logits(self, token_ids, **kwargs):
         self.cache.current_seq_len = 0
-        self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+        if token_ids.shape[-1] > 1:
+            self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+
         return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
 
     def generate_with_streaming(self, prompt, state):
diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index 278d3943..a75ede46 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -60,6 +60,7 @@ class Exllamav2Model:
         result.cache = cache
         result.tokenizer = tokenizer
         result.generator = generator
+        result.loras = None
         return result, result
 
     def encode(self, string, **kwargs):
@@ -75,8 +76,10 @@ class Exllamav2Model:
 
     def get_logits(self, token_ids, **kwargs):
         self.cache.current_seq_len = 0
-        self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
-        return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
+        if token_ids.shape[-1] > 1:
+            self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
+
+        return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()
 
     def generate_with_streaming(self, prompt, state):
         settings = ExLlamaV2Sampler.Settings()
@@ -105,12 +108,12 @@ class Exllamav2Model:
 
         # _gen_begin_base
         self.cache.current_seq_len = 0
-        self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+        self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
 
         has_leading_space = False
         for i in range(max_new_tokens):
-            logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None).float().cpu()
-            token, _, _= ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
+            logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None, loras=self.loras).float().cpu()
+            token, _, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
             ids = torch.cat([ids, token], dim=1)
 
             if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index 71cf513f..e12a0717 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -33,8 +33,8 @@ class Exllamav2HF(PreTrainedModel):
             split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
         self.ex_model.load(split)
-
         self.generation_config = GenerationConfig()
+        self.loras = None
 
         self.ex_cache = ExLlamaV2Cache(self.ex_model)
         self.past_seq = None
@@ -97,7 +97,7 @@ class Exllamav2HF(PreTrainedModel):
                     reset = False
                     ex_cache.current_seq_len = longest_prefix
                     if len(seq_tensor) - longest_prefix > 1:
-                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
+                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
                     elif len(seq_tensor) == longest_prefix:
                         # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
                         # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
@@ -106,12 +106,12 @@ class Exllamav2HF(PreTrainedModel):
             if reset:
                 ex_cache.current_seq_len = 0
                 if len(seq_tensor) > 1:
-                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
+                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
 
-            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
+            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
         else:
             ex_cache.current_seq_len = 0
-            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
+            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
 
         if is_negative:
             self.past_seq_negative = seq_tensor
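
Usage sketch (illustrative, not part of the patch): the pattern this patch wires into add_lora_exllamav2() and Exllamav2Model.get_logits() can be exercised against the exllamav2 library directly, by loading an adapter with ExLlamaV2Lora.from_directory() and passing it through the loras= keyword on every forward() call, including the preprocess_only=True prefill. The model and LoRA directories below are placeholders, and the snippet assumes an exllamav2 build whose forward() accepts loras, which is the same assumption the patch relies on.

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Config,
    ExLlamaV2Lora,
    ExLlamaV2Tokenizer,
)

MODEL_DIR = "/models/my-model-exl2"  # placeholder: EXL2-quantized model directory
LORA_DIR = "/loras/my-exl2-lora"     # placeholder: PEFT-style adapter directory

# Standard exllamav2 setup: config, model, tokenizer, cache.
config = ExLlamaV2Config()
config.model_dir = MODEL_DIR
config.prepare()

model = ExLlamaV2(config)
model.load()
tokenizer = ExLlamaV2Tokenizer(config)
cache = ExLlamaV2Cache(model)

# add_lora_exllamav2() does this once per selected LoRA and keeps the results
# in a list on the model wrapper (shared.model.loras in the webui).
loras = [ExLlamaV2Lora.from_directory(model, LORA_DIR)]

# Mirror of the patched Exllamav2Model.get_logits(): reset the cache, prefill all
# but the last token with preprocess_only=True, then run the last token for logits.
# The adapters take effect because loras= is passed to both forward() calls.
token_ids = tokenizer.encode("Hello")
cache.current_seq_len = 0
if token_ids.shape[-1] > 1:
    model.forward(token_ids[:, :-1], cache, input_mask=None, preprocess_only=True, loras=loras)

logits = model.forward(token_ids[:, -1:], cache, input_mask=None, loras=loras).float().cpu()
print(logits.shape)  # expected: (1, 1, vocab_size)

Swapping adapters at runtime follows the same shape as the patched add_lora_exllamav2(): call .unload() on each currently loaded ExLlamaV2Lora, rebuild the list with the new adapters, or set it back to None to run the base model without any adapters.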