diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py
index 3245ac87..3ba1f3c3 100644
--- a/modules/exllama_hf.py
+++ b/modules/exllama_hf.py
@@ -94,6 +94,10 @@ class ExllamaHF(PreTrainedModel):
                     ex_cache.current_seq_len = longest_prefix
                     if len(seq_tensor) - longest_prefix > 1:
                         self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
+                    elif len(seq_tensor) == longest_prefix:
+                        # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
+                        # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
+                        ex_cache.current_seq_len -= 1

             if reset:
                 ex_cache.current_seq_len = 0
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index 6542ede9..71cf513f 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -98,6 +98,10 @@ class Exllamav2HF(PreTrainedModel):
                     ex_cache.current_seq_len = longest_prefix
                     if len(seq_tensor) - longest_prefix > 1:
                         self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
+                    elif len(seq_tensor) == longest_prefix:
+                        # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
+                        # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
+                        ex_cache.current_seq_len -= 1

             if reset:
                 ex_cache.current_seq_len = 0
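
The sketch below illustrates the off-by-one situation the two hunks fix: when the longest reusable prefix is the entire input, the cache pointer must step back one token so the final forward() call over input_ids[-1] does not cache that token a second time. ToyCache, step() and fake_forward are illustrative stand-ins written for this example, not code from exllama_hf.py or exllamav2_hf.py; only the elif branch mirrors the patch above, and this is a sketch under those assumptions rather than the module's real implementation.

# Minimal sketch of the prefix-reuse logic the two hunks patch.
# ToyCache and step() are hypothetical stand-ins, not part of the repo;
# only the elif off-by-one handling mirrors the diff above.

import torch


class ToyCache:
    """Stands in for ExLlama's cache: current_seq_len counts tokens already cached."""
    def __init__(self):
        self.current_seq_len = 0


def step(seq_tensor, past_seq, ex_cache, forward):
    """Reuse the longest prefix shared by the previous and current token sequences."""
    reset = True
    if past_seq is not None:
        min_length = min(past_seq.shape[0], seq_tensor.shape[0])
        mismatch = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
        longest_prefix = mismatch[0].item() if len(mismatch) > 0 else min_length

        if longest_prefix > 0:
            reset = False
            ex_cache.current_seq_len = longest_prefix
            if len(seq_tensor) - longest_prefix > 1:
                # Pre-process every not-yet-cached token except the last one.
                forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
            elif len(seq_tensor) == longest_prefix:
                # The whole input is already cached, but the final forward() below
                # feeds seq_tensor[-1:] again, so back the pointer up by one token.
                ex_cache.current_seq_len -= 1

    if reset:
        ex_cache.current_seq_len = 0
        if len(seq_tensor) > 1:
            forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)

    # The last token is always passed through forward() to produce logits.
    return forward(seq_tensor[-1:].view(1, -1), ex_cache)


if __name__ == "__main__":
    def fake_forward(tokens, cache, preprocess_only=False):
        cache.current_seq_len += tokens.shape[-1]  # pretend to cache the tokens
        return None if preprocess_only else torch.zeros(1, tokens.shape[-1], 32000)

    cache = ToyCache()
    past = torch.tensor([1, 2, 3, 4])
    cur = torch.tensor([1, 2, 3, 4])          # prefix == input_ids: the tricky case
    step(cur, past, cache, fake_forward)
    assert cache.current_seq_len == len(cur)  # no token ends up cached twice

Without the elif branch, the tricky case above would leave current_seq_len at len(cur) + 1 after the final forward(), which is exactly the drift the patch removes.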