mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-25 13:58:56 +01:00
Exllamav2 lora support (#4229)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
1f5a2c5597
commit
8cce1f1126
@ -13,6 +13,8 @@ def add_lora_to_model(lora_names):
|
||||
add_lora_autogptq(lora_names)
|
||||
elif shared.model.__class__.__name__ in ['ExllamaModel', 'ExllamaHF'] or shared.args.loader == 'ExLlama':
|
||||
add_lora_exllama(lora_names)
|
||||
elif shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav2HF'] or shared.args.loader == ['ExLlamav2', 'ExLlamav2_HF']:
|
||||
add_lora_exllamav2(lora_names)
|
||||
else:
|
||||
add_lora_transformers(lora_names)
|
||||
|
||||
@ -64,8 +66,36 @@ def add_lora_exllama(lora_names):
|
||||
return
|
||||
|
||||
|
||||
# Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
|
||||
def add_lora_exllamav2(lora_names):
|
||||
|
||||
from exllamav2 import ExLlamaV2Lora
|
||||
|
||||
if isinstance(shared.model.loras, list):
|
||||
for lora in shared.model.loras:
|
||||
lora.unload()
|
||||
|
||||
if len(lora_names) > 0:
|
||||
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
||||
shared.model.loras = []
|
||||
for lora_name in lora_names:
|
||||
lora_path = get_lora_path(lora_name)
|
||||
if shared.model.__class__.__name__ == 'Exllamav2Model':
|
||||
lora = ExLlamaV2Lora.from_directory(shared.model.model, str(lora_path))
|
||||
else:
|
||||
lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))
|
||||
|
||||
shared.model.loras.append(lora)
|
||||
|
||||
shared.lora_names = lora_names
|
||||
else:
|
||||
shared.lora_names = []
|
||||
shared.model.loras = None
|
||||
|
||||
|
||||
def add_lora_autogptq(lora_names):
|
||||
'''
|
||||
Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
|
||||
'''
|
||||
|
||||
try:
|
||||
from auto_gptq import get_gptq_peft_model
|
||||
|
@ -98,7 +98,9 @@ class ExllamaModel:
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
if token_ids.shape[-1] > 1:
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
|
||||
return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
|
||||
|
||||
def generate_with_streaming(self, prompt, state):
|
||||
|
@ -60,6 +60,7 @@ class Exllamav2Model:
|
||||
result.cache = cache
|
||||
result.tokenizer = tokenizer
|
||||
result.generator = generator
|
||||
result.loras = None
|
||||
return result, result
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
@ -75,8 +76,10 @@ class Exllamav2Model:
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
|
||||
if token_ids.shape[-1] > 1:
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
|
||||
|
||||
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()
|
||||
|
||||
def generate_with_streaming(self, prompt, state):
|
||||
settings = ExLlamaV2Sampler.Settings()
|
||||
@ -105,12 +108,12 @@ class Exllamav2Model:
|
||||
|
||||
# _gen_begin_base
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
|
||||
|
||||
has_leading_space = False
|
||||
for i in range(max_new_tokens):
|
||||
logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None).float().cpu()
|
||||
token, _, _= ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
|
||||
logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None, loras=self.loras).float().cpu()
|
||||
token, _, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
|
||||
ids = torch.cat([ids, token], dim=1)
|
||||
|
||||
if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
|
||||
|
@ -33,8 +33,8 @@ class Exllamav2HF(PreTrainedModel):
|
||||
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
||||
|
||||
self.ex_model.load(split)
|
||||
|
||||
self.generation_config = GenerationConfig()
|
||||
self.loras = None
|
||||
|
||||
self.ex_cache = ExLlamaV2Cache(self.ex_model)
|
||||
self.past_seq = None
|
||||
@ -97,7 +97,7 @@ class Exllamav2HF(PreTrainedModel):
|
||||
reset = False
|
||||
ex_cache.current_seq_len = longest_prefix
|
||||
if len(seq_tensor) - longest_prefix > 1:
|
||||
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
|
||||
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
|
||||
elif len(seq_tensor) == longest_prefix:
|
||||
# Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
|
||||
# because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
|
||||
@ -106,12 +106,12 @@ class Exllamav2HF(PreTrainedModel):
|
||||
if reset:
|
||||
ex_cache.current_seq_len = 0
|
||||
if len(seq_tensor) > 1:
|
||||
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
|
||||
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
|
||||
|
||||
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
|
||||
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
|
||||
else:
|
||||
ex_cache.current_seq_len = 0
|
||||
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
|
||||
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
|
||||
|
||||
if is_negative:
|
||||
self.past_seq_negative = seq_tensor
|
||||
|
Loading…
Reference in New Issue
Block a user