mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Faster llamacpp_HF prefix matching
This commit is contained in:
parent
893a72a1c5
commit
745807dc03
@@ -123,12 +123,12 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
# https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee
|
# https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee
|
||||||
if labels is None:
|
if labels is None:
|
||||||
if past_seq is not None:
|
if past_seq is not None:
|
||||||
longest_prefix = 0
|
min_length = min(past_seq.shape[0], seq_tensor.shape[0])
|
||||||
for i in range(min(past_seq.shape[0], seq_tensor.shape[0])):
|
indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
|
||||||
if past_seq[i] == seq_tensor[i]:
|
if len(indices) > 0:
|
||||||
longest_prefix += 1
|
longest_prefix = indices[0].item()
|
||||||
else:
|
else:
|
||||||
break
|
longest_prefix = min_length
|
||||||
|
|
||||||
if longest_prefix > 0:
|
if longest_prefix > 0:
|
||||||
self.model.n_tokens = longest_prefix
|
self.model.n_tokens = longest_prefix
|
||||||
|
Loading…
Reference in New Issue
Block a user