mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Make stop_everything work with non-streamed generation (#2848)
This commit is contained in:
parent
ec482f3dae
commit
e356f69b36
@ -9,6 +9,14 @@ import transformers
|
|||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
|
|
||||||
|
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
|
||||||
|
def __init__(self):
|
||||||
|
transformers.StoppingCriteria.__init__(self)
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
|
||||||
|
return shared.stop_everything
|
||||||
|
|
||||||
|
|
||||||
class Stream(transformers.StoppingCriteria):
|
class Stream(transformers.StoppingCriteria):
|
||||||
def __init__(self, callback_func=None):
|
def __init__(self, callback_func=None):
|
||||||
self.callback_func = callback_func
|
self.callback_func = callback_func
|
||||||
|
@ -9,7 +9,8 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.callbacks import Iteratorize, Stream
|
from modules.callbacks import (Iteratorize, Stream,
|
||||||
|
_StopEverythingStoppingCriteria)
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
@ -252,10 +253,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
|
|
||||||
# Find the eos tokens
|
# Stopping criteria / eos token
|
||||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||||
generate_params['eos_token_id'] = eos_token_ids
|
generate_params['eos_token_id'] = eos_token_ids
|
||||||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||||
|
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria());
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user