Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-21 23:57:58 +01:00)
llamacpp_HF prefix matching
This commit is contained in:
parent
763ea3bcb2
commit
d71465708c
@@ -117,14 +117,27 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
seq = past_key_values + seq
|
seq = past_key_values + seq
|
||||||
|
|
||||||
seq_tensor = torch.tensor(seq)
|
seq_tensor = torch.tensor(seq)
|
||||||
|
reset = True
|
||||||
|
|
||||||
# Make the forward call
|
# Make the forward call. The prefix-match code has been adapted from
|
||||||
|
# https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee
|
||||||
if labels is None:
|
if labels is None:
|
||||||
if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
|
if past_seq is not None:
|
||||||
|
longest_prefix = 0
|
||||||
|
for i in range(min(past_seq.shape[0], seq_tensor.shape[0])):
|
||||||
|
if past_seq[i] == seq_tensor[i]:
|
||||||
|
longest_prefix += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if longest_prefix > 0:
|
||||||
|
self.model.n_tokens = longest_prefix
|
||||||
|
self.model.eval(seq[longest_prefix:])
|
||||||
|
reset = False
|
||||||
|
|
||||||
|
if reset:
|
||||||
self.model.reset()
|
self.model.reset()
|
||||||
self.model.eval(seq)
|
self.model.eval(seq)
|
||||||
else:
|
|
||||||
self.model.eval([seq[-1]])
|
|
||||||
|
|
||||||
logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device)
|
logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user