mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Fix off-by-one error in exllama_hf caching logic (#4145)
This commit is contained in:
parent
b04c08378d
commit
cb26163a20
@ -94,6 +94,10 @@ class ExllamaHF(PreTrainedModel):
|
||||
ex_cache.current_seq_len = longest_prefix
|
||||
if len(seq_tensor) - longest_prefix > 1:
|
||||
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
|
||||
elif len(seq_tensor) == longest_prefix:
|
||||
# Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
|
||||
# because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
|
||||
ex_cache.current_seq_len -= 1
|
||||
|
||||
if reset:
|
||||
ex_cache.current_seq_len = 0
|
||||
|
@ -98,6 +98,10 @@ class Exllamav2HF(PreTrainedModel):
|
||||
ex_cache.current_seq_len = longest_prefix
|
||||
if len(seq_tensor) - longest_prefix > 1:
|
||||
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
|
||||
elif len(seq_tensor) == longest_prefix:
|
||||
# Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
|
||||
# because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
|
||||
ex_cache.current_seq_len -= 1
|
||||
|
||||
if reset:
|
||||
ex_cache.current_seq_len = 0
|
||||
|
Loading…
Reference in New Issue
Block a user