Repository: https://github.com/oobabooga/text-generation-webui.git
Commit: 03dc69edc5
Parent: 5075087461

ExLlama_HF (v1 and v2) prefix matching
modules/exllama_hf.py

@@ -77,17 +77,33 @@ class ExllamaHF(PreTrainedModel):
             seq = past_key_values + seq
 
         seq_tensor = torch.tensor(seq)
+        reset = True
 
         # Make the forward call
         if labels is None:
-            if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
-                ex_cache.current_seq_len = 0
-                self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True, lora=self.lora)
+            if past_seq is not None:
+                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+                if len(indices) > 0:
+                    longest_prefix = indices[0].item()
+                else:
+                    longest_prefix = min_length
 
-            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache, lora=self.lora).to(input_ids.device)
+                if longest_prefix > 0:
+                    reset = False
+                    ex_cache.current_seq_len = longest_prefix
+                    if len(seq_tensor) - longest_prefix > 1:
+                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
+
+            if reset:
+                ex_cache.current_seq_len = 0
+                if len(seq_tensor) > 1:
+                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
+
+            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, lora=self.lora).to(input_ids.device)
         else:
             ex_cache.current_seq_len = 0
-            logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False, lora=self.lora)
+            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, lora=self.lora)
 
         if is_negative:
             self.past_seq_negative = seq_tensor
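The heart of the change is the longest-common-prefix computation between the cached token sequence and the incoming one. As a minimal standalone sketch (the function name and the example token ids are invented for illustration; only the tensor logic mirrors the diff above):

import torch

def longest_common_prefix(past_seq: torch.Tensor, seq_tensor: torch.Tensor) -> int:
    # Compare the two sequences only over their overlapping range
    min_length = min(past_seq.shape[0], seq_tensor.shape[0])
    # Indices where the cached and the new sequence disagree
    indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
    if len(indices) > 0:
        # The first mismatch bounds the reusable prefix
        return indices[0].item()
    # No mismatch: one sequence is a prefix of the other
    return min_length

# A regenerated reply typically shares the prompt and diverges afterwards
past_seq = torch.tensor([1, 5, 9, 4, 2])
seq_tensor = torch.tensor([1, 5, 9, 7])
print(longest_common_prefix(past_seq, seq_tensor))  # -> 3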
modules/exllamav2_hf.py

@@ -81,18 +81,33 @@ class Exllamav2HF(PreTrainedModel):
             seq = past_key_values + seq
 
         seq_tensor = torch.tensor(seq)
+        reset = True
 
         # Make the forward call
         if labels is None:
-            if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
-                ex_cache.current_seq_len = 0
-                self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True)
+            if past_seq is not None:
+                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+                if len(indices) > 0:
+                    longest_prefix = indices[0].item()
+                else:
+                    longest_prefix = min_length
 
-            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache).to(input_ids.device)
+                if longest_prefix > 0:
+                    reset = False
+                    ex_cache.current_seq_len = longest_prefix
+                    if len(seq_tensor) - longest_prefix > 1:
+                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
+
+            if reset:
+                ex_cache.current_seq_len = 0
+                if len(seq_tensor) > 1:
+                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
+
+            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
         else:
             ex_cache.current_seq_len = 0
-            # logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False)
-            logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache)
+            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
 
         if is_negative:
             self.past_seq_negative = seq_tensor
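Taken together, the reset/current_seq_len handling means only the tokens past the matched prefix are re-ingested with preprocess_only=True before the final single-token forward call; the cache is rebuilt from scratch only when nothing matches. A toy sketch of that control flow, assuming a hypothetical ToyCache stand-in (not the ExLlama API; only the branching mirrors the diff):

import torch

class ToyCache:
    # Stand-in for ex_cache: tracks only how many tokens are resident
    def __init__(self):
        self.current_seq_len = 0

def tokens_to_reingest(past_seq, seq_tensor, cache):
    # How many tokens would go through the preprocess-only forward
    # pass before the final single-token call
    reset = True
    processed = 0
    if past_seq is not None:
        min_length = min(past_seq.shape[0], seq_tensor.shape[0])
        indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
        longest_prefix = indices[0].item() if len(indices) > 0 else min_length
        if longest_prefix > 0:
            reset = False
            cache.current_seq_len = longest_prefix  # rewind to the shared prefix
            if len(seq_tensor) - longest_prefix > 1:
                processed = len(seq_tensor) - 1 - longest_prefix
    if reset:
        cache.current_seq_len = 0  # no reusable prefix: start over
        if len(seq_tensor) > 1:
            processed = len(seq_tensor) - 1
    return processed

cache = ToyCache()
past = torch.arange(100)                               # previous 100-token sequence
seq = torch.cat([past[:80], torch.tensor([7, 8, 9])])  # first 80 tokens unchanged
print(tokens_to_reingest(past, seq, cache))            # -> 2, instead of 82 on a full reset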