From a199f2179999467b299fec3128e6298f5895c223 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 16 Jul 2023 20:49:48 -0700
Subject: [PATCH] Optimize llamacpp_hf a bit

---
 modules/llamacpp_hf.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py
index 12212fec..5d05f5df 100644
--- a/modules/llamacpp_hf.py
+++ b/modules/llamacpp_hf.py
@@ -42,7 +42,6 @@ class LlamacppHF(PreTrainedModel):
 
         # Make the forward call
         seq_tensor = torch.tensor(seq)
-        self.cache = seq_tensor
         if labels is None:
             if self.cache is None or not torch.equal(self.cache, seq_tensor[:-1]):
                 self.model.reset()
@@ -50,13 +49,15 @@ class LlamacppHF(PreTrainedModel):
         else:
             self.model.eval([seq[-1]])
 
-            logits = torch.tensor(self.model.eval_logits)[-1].view(1, 1, -1).to(kwargs['input_ids'].device)
+            logits = torch.tensor(self.model.eval_logits[-1]).view(1, 1, -1).to(kwargs['input_ids'].device)
         else:
             self.model.reset()
             self.model.eval(seq)
             logits = torch.tensor(self.model.eval_logits)
             logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
 
+        self.cache = seq_tensor
+
         # Based on transformers/models/llama/modeling_llama.py
         loss = None
         if labels is not None: