diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py
index 9e4701bf..3245ac87 100644
--- a/modules/exllama_hf.py
+++ b/modules/exllama_hf.py
@@ -77,17 +77,33 @@ class ExllamaHF(PreTrainedModel):
             seq = past_key_values + seq
 
         seq_tensor = torch.tensor(seq)
+        reset = True
 
         # Make the forward call
         if labels is None:
-            if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
-                ex_cache.current_seq_len = 0
-                self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True, lora=self.lora)
+            if past_seq is not None:
+                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+                if len(indices) > 0:
+                    longest_prefix = indices[0].item()
+                else:
+                    longest_prefix = min_length
 
-            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache, lora=self.lora).to(input_ids.device)
+                if longest_prefix > 0:
+                    reset = False
+                    ex_cache.current_seq_len = longest_prefix
+                    if len(seq_tensor) - longest_prefix > 1:
+                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
+
+            if reset:
+                ex_cache.current_seq_len = 0
+                if len(seq_tensor) > 1:
+                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
+
+            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, lora=self.lora).to(input_ids.device)
         else:
             ex_cache.current_seq_len = 0
-            logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False, lora=self.lora)
+            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, lora=self.lora)
 
         if is_negative:
             self.past_seq_negative = seq_tensor
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index 457942ac..6542ede9 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -81,18 +81,33 @@ class Exllamav2HF(PreTrainedModel):
             seq = past_key_values + seq
 
         seq_tensor = torch.tensor(seq)
+        reset = True
 
         # Make the forward call
         if labels is None:
-            if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
-                ex_cache.current_seq_len = 0
-                self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True)
+            if past_seq is not None:
+                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+                if len(indices) > 0:
+                    longest_prefix = indices[0].item()
+                else:
+                    longest_prefix = min_length
 
-            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache).to(input_ids.device)
+                if longest_prefix > 0:
+                    reset = False
+                    ex_cache.current_seq_len = longest_prefix
+                    if len(seq_tensor) - longest_prefix > 1:
+                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
+
+            if reset:
+                ex_cache.current_seq_len = 0
+                if len(seq_tensor) > 1:
+                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
+
+            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
         else:
             ex_cache.current_seq_len = 0
-            # logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False)
-            logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache)
+            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
 
         if is_negative:
             self.past_seq_negative = seq_tensor
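
Below is a minimal, self-contained sketch (not part of the patch) of the prefix-matching idea both hunks implement: instead of resetting ex_cache.current_seq_len to 0 whenever the new sequence is not an exact one-token extension of the cached one, the code finds the longest shared prefix between past_seq and seq_tensor and only re-ingests the tokens after it. The helper name longest_common_prefix is illustrative and does not appear in the diff.

import torch

def longest_common_prefix(past_seq: torch.Tensor, seq_tensor: torch.Tensor) -> int:
    """Number of leading tokens shared by the cached sequence and the new one."""
    min_length = min(past_seq.shape[0], seq_tensor.shape[0])
    # First position (within the overlap) where the two sequences differ
    indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
    if len(indices) > 0:
        return indices[0].item()

    return min_length

# The cache can keep the first 3 positions; only the tokens after the prefix
# (minus the final one, which produces the logits) need a preprocess pass.
past = torch.tensor([1, 2, 3, 4, 5])
new = torch.tensor([1, 2, 3, 7, 8, 9])
print(longest_common_prefix(past, new))  # 3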