mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 01:09:22 +01:00
ExLlama_HF (v1 and v2) prefix matching
This commit is contained in:
parent
5075087461
commit
03dc69edc5
@@ -77,17 +77,33 @@ class ExllamaHF(PreTrainedModel):
|
|||||||
seq = past_key_values + seq
|
seq = past_key_values + seq
|
||||||
|
|
||||||
seq_tensor = torch.tensor(seq)
|
seq_tensor = torch.tensor(seq)
|
||||||
|
reset = True
|
||||||
|
|
||||||
# Make the forward call
|
# Make the forward call
|
||||||
if labels is None:
|
if labels is None:
|
||||||
if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
|
if past_seq is not None:
|
||||||
ex_cache.current_seq_len = 0
|
min_length = min(past_seq.shape[0], seq_tensor.shape[0])
|
||||||
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True, lora=self.lora)
|
indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
|
||||||
|
if len(indices) > 0:
|
||||||
|
longest_prefix = indices[0].item()
|
||||||
|
else:
|
||||||
|
longest_prefix = min_length
|
||||||
|
|
||||||
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache, lora=self.lora).to(input_ids.device)
|
if longest_prefix > 0:
|
||||||
|
reset = False
|
||||||
|
ex_cache.current_seq_len = longest_prefix
|
||||||
|
if len(seq_tensor) - longest_prefix > 1:
|
||||||
|
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
|
||||||
|
|
||||||
|
if reset:
|
||||||
|
ex_cache.current_seq_len = 0
|
||||||
|
if len(seq_tensor) > 1:
|
||||||
|
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
|
||||||
|
|
||||||
|
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, lora=self.lora).to(input_ids.device)
|
||||||
else:
|
else:
|
||||||
ex_cache.current_seq_len = 0
|
ex_cache.current_seq_len = 0
|
||||||
logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False, lora=self.lora)
|
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, lora=self.lora)
|
||||||
|
|
||||||
if is_negative:
|
if is_negative:
|
||||||
self.past_seq_negative = seq_tensor
|
self.past_seq_negative = seq_tensor
|
||||||
|
@@ -81,18 +81,33 @@ class Exllamav2HF(PreTrainedModel):
|
|||||||
seq = past_key_values + seq
|
seq = past_key_values + seq
|
||||||
|
|
||||||
seq_tensor = torch.tensor(seq)
|
seq_tensor = torch.tensor(seq)
|
||||||
|
reset = True
|
||||||
|
|
||||||
# Make the forward call
|
# Make the forward call
|
||||||
if labels is None:
|
if labels is None:
|
||||||
if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
|
if past_seq is not None:
|
||||||
ex_cache.current_seq_len = 0
|
min_length = min(past_seq.shape[0], seq_tensor.shape[0])
|
||||||
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True)
|
indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
|
||||||
|
if len(indices) > 0:
|
||||||
|
longest_prefix = indices[0].item()
|
||||||
|
else:
|
||||||
|
longest_prefix = min_length
|
||||||
|
|
||||||
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache).to(input_ids.device)
|
if longest_prefix > 0:
|
||||||
|
reset = False
|
||||||
|
ex_cache.current_seq_len = longest_prefix
|
||||||
|
if len(seq_tensor) - longest_prefix > 1:
|
||||||
|
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
|
||||||
|
|
||||||
|
if reset:
|
||||||
|
ex_cache.current_seq_len = 0
|
||||||
|
if len(seq_tensor) > 1:
|
||||||
|
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
|
||||||
|
|
||||||
|
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
|
||||||
else:
|
else:
|
||||||
ex_cache.current_seq_len = 0
|
ex_cache.current_seq_len = 0
|
||||||
# logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False)
|
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
|
||||||
logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache)
|
|
||||||
|
|
||||||
if is_negative:
|
if is_negative:
|
||||||
self.past_seq_negative = seq_tensor
|
self.past_seq_negative = seq_tensor
|
||||||
|
Loading…
Reference in New Issue
Block a user