Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-22 08:07:56 +01:00)
Exllamav2 lora support (#4229)
Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>

commit 8cce1f1126
parent 1f5a2c5597
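The change set below wires LoRA support into the two ExLlamaV2 code paths (`Exllamav2Model` and `Exllamav2HF`): adapters are loaded with `ExLlamaV2Lora.from_directory()`, kept on the model object as `shared.model.loras`, and handed to every `forward()` call through the `loras=` argument. A minimal sketch of how the feature is driven from the web UI side; it assumes a model is already loaded with one of the ExLlamaV2 loaders, and `my-adapter` is a placeholder directory name:

```python
# Sketch only: assumes the web UI's modules are importable and a model is loaded.
from modules import shared
from modules.LoRA import add_lora_to_model

add_lora_to_model(['my-adapter'])   # hypothetical adapter dir; dispatches to add_lora_exllamav2()
print(shared.lora_names)            # ['my-adapter']
print(shared.model.loras)           # list of ExLlamaV2Lora objects passed to forward(..., loras=...)

add_lora_to_model([])               # unload: shared.model.loras goes back to None
```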
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -13,6 +13,8 @@ def add_lora_to_model(lora_names):
         add_lora_autogptq(lora_names)
     elif shared.model.__class__.__name__ in ['ExllamaModel', 'ExllamaHF'] or shared.args.loader == 'ExLlama':
         add_lora_exllama(lora_names)
+    elif shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav2HF'] or shared.args.loader in ['ExLlamav2', 'ExLlamav2_HF']:
+        add_lora_exllamav2(lora_names)
     else:
         add_lora_transformers(lora_names)
 
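A detail that is easy to miss in the dispatch above: the loader check must be a membership test (`in`) against the list of loader names, because comparing the string `shared.args.loader` to a list with `==` is always False. A two-line illustration:

```python
loader = 'ExLlamav2_HF'
print(loader == ['ExLlamav2', 'ExLlamav2_HF'])  # False: a string never equals a list
print(loader in ['ExLlamav2', 'ExLlamav2_HF'])  # True: membership is what the dispatch needs
```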
@@ -64,8 +66,36 @@ def add_lora_exllama(lora_names):
         return
 
 
+# Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
+def add_lora_exllamav2(lora_names):
+
+    from exllamav2 import ExLlamaV2Lora
+
+    if isinstance(shared.model.loras, list):
+        for lora in shared.model.loras:
+            lora.unload()
+
+    if len(lora_names) > 0:
+        logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
+        shared.model.loras = []
+        for lora_name in lora_names:
+            lora_path = get_lora_path(lora_name)
+            if shared.model.__class__.__name__ == 'Exllamav2Model':
+                lora = ExLlamaV2Lora.from_directory(shared.model.model, str(lora_path))
+            else:
+                lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))
+
+            shared.model.loras.append(lora)
+
+        shared.lora_names = lora_names
+    else:
+        shared.lora_names = []
+        shared.model.loras = None
+
+
 def add_lora_autogptq(lora_names):
     '''
     Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
     '''
 
     try:
         from auto_gptq import get_gptq_peft_model
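For reference, the same mechanism stripped of the web UI glue: an adapter directory becomes an `ExLlamaV2Lora` object, which is then passed to every `forward()` call in a `loras=` list. A minimal sketch, assuming `model` is an already-initialized ExLlamaV2 model, `cache` its cache, `input_ids` a token tensor, and `loras/my-adapter` a hypothetical adapter directory:

```python
from exllamav2 import ExLlamaV2Lora

# Load one adapter for an existing model (hypothetical path).
lora = ExLlamaV2Lora.from_directory(model, 'loras/my-adapter')

# The adapter is applied at inference time: the list is passed on each call
# rather than being merged into the base weights.
logits = model.forward(input_ids, cache, loras=[lora])

# Dropping the adapter: stop passing it and release its tensors.
lora.unload()
logits = model.forward(input_ids, cache, loras=None)
```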
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -98,7 +98,9 @@ class ExllamaModel:
 
     def get_logits(self, token_ids, **kwargs):
         self.cache.current_seq_len = 0
-        self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+        if token_ids.shape[-1] > 1:
+            self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+
         return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
 
     def generate_with_streaming(self, prompt, state):
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -60,6 +60,7 @@ class Exllamav2Model:
         result.cache = cache
         result.tokenizer = tokenizer
         result.generator = generator
+        result.loras = None
         return result, result
 
     def encode(self, string, **kwargs):
@@ -75,8 +76,10 @@ class Exllamav2Model:
 
     def get_logits(self, token_ids, **kwargs):
         self.cache.current_seq_len = 0
-        self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
-        return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
+        if token_ids.shape[-1] > 1:
+            self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
+
+        return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()
 
     def generate_with_streaming(self, prompt, state):
         settings = ExLlamaV2Sampler.Settings()
@@ -105,11 +108,11 @@ class Exllamav2Model:
 
         # _gen_begin_base
         self.cache.current_seq_len = 0
-        self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+        self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
 
         has_leading_space = False
         for i in range(max_new_tokens):
-            logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None).float().cpu()
+            logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None, loras=self.loras).float().cpu()
             token, _, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
             ids = torch.cat([ids, token], dim=1)
 
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -33,8 +33,9 @@ class Exllamav2HF(PreTrainedModel):
             split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
         self.ex_model.load(split)
 
         self.generation_config = GenerationConfig()
+        self.loras = None
 
         self.ex_cache = ExLlamaV2Cache(self.ex_model)
         self.past_seq = None
@@ -97,7 +97,7 @@ class Exllamav2HF(PreTrainedModel):
                     reset = False
                     ex_cache.current_seq_len = longest_prefix
                     if len(seq_tensor) - longest_prefix > 1:
-                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
+                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
                     elif len(seq_tensor) == longest_prefix:
                         # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
                         # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
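The "very tricky" case above is easier to see with concrete numbers. If the previous call cached `[A, B, C, D]` and the new `input_ids` are again `[A, B, C, D]`, the longest shared prefix is the whole input, so there is no new tail to prefill, yet `forward()` below is still fed `input_ids[-1]` (`D`). Backing `current_seq_len` up by one makes that call overwrite the cache slot for `D` instead of appending a duplicate. A toy walk-through; the prefix computation is paraphrased, not the exact code from `__call__`:

```python
import torch

past_seq   = torch.tensor([10, 11, 12, 13])   # what the cache currently holds
seq_tensor = torch.tensor([10, 11, 12, 13])   # the new input_ids

# Longest shared prefix (paraphrase of the comparison done in __call__).
min_len = min(len(past_seq), len(seq_tensor))
diff = torch.nonzero(past_seq[:min_len] != seq_tensor[:min_len])
longest_prefix = diff[0].item() if len(diff) > 0 else min_len   # -> 4

current_seq_len = longest_prefix              # 4 tokens already cached
if len(seq_tensor) - longest_prefix > 1:
    pass                                      # would prefill the new tail (not the case here)
elif len(seq_tensor) == longest_prefix:
    current_seq_len -= 1                      # back up: forward() re-feeds the last token
print(current_seq_len)                        # 3 -> the call on seq_tensor[-1:] lands on slot 3
```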
@@ -106,12 +106,12 @@ class Exllamav2HF(PreTrainedModel):
             if reset:
                 ex_cache.current_seq_len = 0
                 if len(seq_tensor) > 1:
-                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
+                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
 
-                logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
+                logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
         else:
             ex_cache.current_seq_len = 0
-            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
+            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
 
         if is_negative:
             self.past_seq_negative = seq_tensor