diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py index 12212fec..5d05f5df 100644 --- a/modules/llamacpp_hf.py +++ b/modules/llamacpp_hf.py @@ -42,7 +42,6 @@ class LlamacppHF(PreTrainedModel): # Make the forward call seq_tensor = torch.tensor(seq) - self.cache = seq_tensor if labels is None: if self.cache is None or not torch.equal(self.cache, seq_tensor[:-1]): self.model.reset() @@ -50,13 +49,15 @@ class LlamacppHF(PreTrainedModel): else: self.model.eval([seq[-1]]) - logits = torch.tensor(self.model.eval_logits)[-1].view(1, 1, -1).to(kwargs['input_ids'].device) + logits = torch.tensor(self.model.eval_logits[-1]).view(1, 1, -1).to(kwargs['input_ids'].device) else: self.model.reset() self.model.eval(seq) logits = torch.tensor(self.model.eval_logits) logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device) + self.cache = seq_tensor + # Based on transformers/models/llama/modeling_llama.py loss = None if labels is not None: