diff --git a/modules/exllama.py b/modules/exllama.py
index 80607579..f685a445 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -1,10 +1,10 @@
-import sys
 from pathlib import Path
 
 from torch import version as torch_version
 
 from modules import shared
 from modules.logging_colors import logger
+from modules.text_generation import get_max_prompt_length
 
 try:
     from exllama.generator import ExLlamaGenerator
@@ -90,7 +90,11 @@ class ExllamaModel:
             self.generator.disallow_tokens(None)
 
         self.generator.end_beam_search()
+
+        # Tokenizing the input
         ids = self.generator.tokenizer.encode(prompt)
+        ids = ids[:, -get_max_prompt_length(state):]
+
         self.generator.gen_begin_reuse(ids)
         initial_len = self.generator.sequence[0].shape[0]
         has_leading_space = False