From d71465708c8b3ad9be9fdaa6313b0b6cc81bfd27 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 17 Sep 2023 11:50:47 -0700 Subject: [PATCH] llamacpp_HF prefix matching --- modules/llamacpp_hf.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py index 6cba5d95..92de102c 100644 --- a/modules/llamacpp_hf.py +++ b/modules/llamacpp_hf.py @@ -117,14 +117,27 @@ class LlamacppHF(PreTrainedModel): seq = past_key_values + seq seq_tensor = torch.tensor(seq) + reset = True - # Make the forward call + # Make the forward call. The prefix-match code has been adapted from + # https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee if labels is None: - if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]): + if past_seq is not None: + longest_prefix = 0 + for i in range(min(past_seq.shape[0], seq_tensor.shape[0])): + if past_seq[i] == seq_tensor[i]: + longest_prefix += 1 + else: + break + + if longest_prefix > 0: + self.model.n_tokens = longest_prefix + self.model.eval(seq[longest_prefix:]) + reset = False + + if reset: self.model.reset() self.model.eval(seq) - else: - self.model.eval([seq[-1]]) logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device) else: