diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py index 9beb2269..d7dada08 100644 --- a/modules/exllama_hf.py +++ b/modules/exllama_hf.py @@ -54,7 +54,15 @@ class ExllamaHF(PreTrainedModel): cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None if cache is None: cache = ExLlamaCache(self.ex_model) - self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora) + + nseq = seq[:-1] + for seqs in [nseq[i : i + 2048] for i in range(0, len(nseq), 2048)]: + self.ex_model.forward( + torch.tensor([seqs], dtype=torch.long), + cache, + preprocess_only=True, + lora=self.lora, + ) logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)