Add truncation to exllama

This commit is contained in:
oobabooga 2023-07-07 09:09:23 -07:00
parent c21b73ff37
commit 1ba2e88551

View File

@@ -1,10 +1,10 @@
import sys
from pathlib import Path
from torch import version as torch_version
from modules import shared
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length
try:
    from exllama.generator import ExLlamaGenerator
@@ -90,7 +90,11 @@ class ExllamaModel:
        self.generator.disallow_tokens(None)
        self.generator.end_beam_search()
# Tokenizing the input
        ids = self.generator.tokenizer.encode(prompt)
ids = ids[:, -get_max_prompt_length(state):]
        self.generator.gen_begin_reuse(ids)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False