Exllamav2 lora support (#4229)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Forkoz 2023-10-14 19:12:41 +00:00 committed by GitHub
parent 1f5a2c5597
commit 8cce1f1126
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 12 deletions

View File

@ -13,6 +13,8 @@ def add_lora_to_model(lora_names):
add_lora_autogptq(lora_names)
elif shared.model.__class__.__name__ in ['ExllamaModel', 'ExllamaHF'] or shared.args.loader == 'ExLlama':
add_lora_exllama(lora_names)
elif shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav2HF'] or shared.args.loader == ['ExLlamav2', 'ExLlamav2_HF']:
add_lora_exllamav2(lora_names)
else:
add_lora_transformers(lora_names)
@ -64,8 +66,36 @@ def add_lora_exllama(lora_names):
return
# Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
def add_lora_exllamav2(lora_names):
from exllamav2 import ExLlamaV2Lora
if isinstance(shared.model.loras, list):
for lora in shared.model.loras:
lora.unload()
if len(lora_names) > 0:
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
shared.model.loras = []
for lora_name in lora_names:
lora_path = get_lora_path(lora_name)
if shared.model.__class__.__name__ == 'Exllamav2Model':
lora = ExLlamaV2Lora.from_directory(shared.model.model, str(lora_path))
else:
lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))
shared.model.loras.append(lora)
shared.lora_names = lora_names
else:
shared.lora_names = []
shared.model.loras = None
def add_lora_autogptq(lora_names):
'''
Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
'''
try:
from auto_gptq import get_gptq_peft_model

View File

@ -98,7 +98,9 @@ class ExllamaModel:
def get_logits(self, token_ids, **kwargs):
self.cache.current_seq_len = 0
if token_ids.shape[-1] > 1:
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
def generate_with_streaming(self, prompt, state):

View File

@ -60,6 +60,7 @@ class Exllamav2Model:
result.cache = cache
result.tokenizer = tokenizer
result.generator = generator
result.loras = None
return result, result
def encode(self, string, **kwargs):
@ -75,8 +76,10 @@ class Exllamav2Model:
def get_logits(self, token_ids, **kwargs):
self.cache.current_seq_len = 0
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
if token_ids.shape[-1] > 1:
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()
def generate_with_streaming(self, prompt, state):
settings = ExLlamaV2Sampler.Settings()
@ -105,12 +108,12 @@ class Exllamav2Model:
# _gen_begin_base
self.cache.current_seq_len = 0
self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
has_leading_space = False
for i in range(max_new_tokens):
logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None).float().cpu()
token, _, _= ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None, loras=self.loras).float().cpu()
token, _, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
ids = torch.cat([ids, token], dim=1)
if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith(''):

View File

@ -33,8 +33,8 @@ class Exllamav2HF(PreTrainedModel):
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
self.ex_model.load(split)
self.generation_config = GenerationConfig()
self.loras = None
self.ex_cache = ExLlamaV2Cache(self.ex_model)
self.past_seq = None
@ -97,7 +97,7 @@ class Exllamav2HF(PreTrainedModel):
reset = False
ex_cache.current_seq_len = longest_prefix
if len(seq_tensor) - longest_prefix > 1:
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
elif len(seq_tensor) == longest_prefix:
# Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
# because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
@ -106,12 +106,12 @@ class Exllamav2HF(PreTrainedModel):
if reset:
ex_cache.current_seq_len = 0
if len(seq_tensor) > 1:
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
else:
ex_cache.current_seq_len = 0
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
if is_negative:
self.past_seq_negative = seq_tensor