mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Make stop_everything work with non-streamed generation (#2848)
This commit is contained in:
parent
ec482f3dae
commit
e356f69b36
@ -9,6 +9,14 @@ import transformers
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
|
||||
def __init__(self):
|
||||
transformers.StoppingCriteria.__init__(self)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
|
||||
return shared.stop_everything
|
||||
|
||||
|
||||
class Stream(transformers.StoppingCriteria):
|
||||
def __init__(self, callback_func=None):
|
||||
self.callback_func = callback_func
|
||||
|
@ -9,7 +9,8 @@ import torch
|
||||
import transformers
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.callbacks import Iteratorize, Stream
|
||||
from modules.callbacks import (Iteratorize, Stream,
|
||||
_StopEverythingStoppingCriteria)
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||
from modules.logging_colors import logger
|
||||
@ -252,10 +253,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
# Find the eos tokens
|
||||
# Stopping criteria / eos token
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria());
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user