Faster llamacpp_HF prefix matching

This commit is contained in:
oobabooga 2023-09-18 11:02:45 -07:00
parent 893a72a1c5
commit 745807dc03

View File

@@ -123,12 +123,12 @@ class LlamacppHF(PreTrainedModel):
         # https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee
         if labels is None:
             if past_seq is not None:
-                longest_prefix = 0
-                for i in range(min(past_seq.shape[0], seq_tensor.shape[0])):
-                    if past_seq[i] == seq_tensor[i]:
-                        longest_prefix += 1
-                    else:
-                        break
+                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+                if len(indices) > 0:
+                    longest_prefix = indices[0].item()
+                else:
+                    longest_prefix = min_length
                 if longest_prefix > 0:
                     self.model.n_tokens = longest_prefix