llamacpp_HF prefix matching

This commit is contained in:
oobabooga 2023-09-17 11:50:47 -07:00
parent 763ea3bcb2
commit d71465708c

View File

@@ -117,14 +117,27 @@ class LlamacppHF(PreTrainedModel):
             seq = past_key_values + seq

         seq_tensor = torch.tensor(seq)
+        reset = True

-        # Make the forward call
+        # Make the forward call. The prefix-match code has been adapted from
+        # https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee
         if labels is None:
-            if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
+            if past_seq is not None:
+                longest_prefix = 0
+                for i in range(min(past_seq.shape[0], seq_tensor.shape[0])):
+                    if past_seq[i] == seq_tensor[i]:
+                        longest_prefix += 1
+                    else:
+                        break
+
+                if longest_prefix > 0:
+                    self.model.n_tokens = longest_prefix
+                    self.model.eval(seq[longest_prefix:])
+                    reset = False
+
+            if reset:
                 self.model.reset()
                 self.model.eval(seq)
-            else:
-                self.model.eval([seq[-1]])

             logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device)
         else: