From 745807dc037155708080674e755da07ae335d08e Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 18 Sep 2023 11:02:45 -0700
Subject: [PATCH] Faster llamacpp_HF prefix matching

---
 modules/llamacpp_hf.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py
index 92de102c..00da92ed 100644
--- a/modules/llamacpp_hf.py
+++ b/modules/llamacpp_hf.py
@@ -123,12 +123,12 @@ class LlamacppHF(PreTrainedModel):
         # https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee
         if labels is None:
             if past_seq is not None:
-                longest_prefix = 0
-                for i in range(min(past_seq.shape[0], seq_tensor.shape[0])):
-                    if past_seq[i] == seq_tensor[i]:
-                        longest_prefix += 1
-                    else:
-                        break
+                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+                if len(indices) > 0:
+                    longest_prefix = indices[0].item()
+                else:
+                    longest_prefix = min_length
 
                 if longest_prefix > 0:
                     self.model.n_tokens = longest_prefix
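
The patch replaces a per-token Python loop with a single vectorized comparison: torch.eq compares the overlapping region of the cached and incoming token sequences element-wise, and the first nonzero (mismatching) position is the length of the shared prefix. Below is a minimal standalone sketch of that technique; the function name longest_common_prefix and the example tensors are illustrative and not part of the patch, but the body mirrors the added lines.

    import torch

    def longest_common_prefix(past_seq: torch.Tensor, seq_tensor: torch.Tensor) -> int:
        # Compare only the overlapping region of the two 1-D token-ID tensors.
        min_length = min(past_seq.shape[0], seq_tensor.shape[0])
        # Indices of positions where the two sequences disagree.
        indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
        if len(indices) > 0:
            # The first mismatch index equals the length of the shared prefix.
            return indices[0].item()
        # No mismatch: one sequence is a prefix of the other.
        return min_length

    # Example: the sequences share the prefix [1, 2, 3] before diverging.
    past = torch.tensor([1, 2, 3, 4, 5])
    new = torch.tensor([1, 2, 3, 9])
    assert longest_common_prefix(past, new) == 3

The result is the same as the old loop, but the comparison runs as one tensor operation instead of one Python iteration per token, which is what makes the prefix matching faster on long contexts.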