mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 10:59:32 +01:00
Add truncation to exllama
This commit is contained in:
parent
c21b73ff37
commit
1ba2e88551
@ -1,10 +1,10 @@
|
|||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from torch import version as torch_version
|
from torch import version as torch_version
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
from modules.text_generation import get_max_prompt_length
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from exllama.generator import ExLlamaGenerator
|
from exllama.generator import ExLlamaGenerator
|
||||||
@ -90,7 +90,11 @@ class ExllamaModel:
|
|||||||
self.generator.disallow_tokens(None)
|
self.generator.disallow_tokens(None)
|
||||||
|
|
||||||
self.generator.end_beam_search()
|
self.generator.end_beam_search()
|
||||||
|
|
||||||
|
# Tokenizing the input
|
||||||
ids = self.generator.tokenizer.encode(prompt)
|
ids = self.generator.tokenizer.encode(prompt)
|
||||||
|
ids = ids[:, -get_max_prompt_length(state):]
|
||||||
|
|
||||||
self.generator.gen_begin_reuse(ids)
|
self.generator.gen_begin_reuse(ids)
|
||||||
initial_len = self.generator.sequence[0].shape[0]
|
initial_len = self.generator.sequence[0].shape[0]
|
||||||
has_leading_space = False
|
has_leading_space = False
|
||||||
|
Loading…
Reference in New Issue
Block a user