Fix streaming callback

oobabooga 2023-06-16 21:44:56 -03:00
parent 1e400218e9
commit 7f06d551a3

@@ -256,14 +256,14 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
     # This is based on the trick of using 'stopping_criteria' to create an iterator.
     else:
 
-        def generate_with_callback(callback=None, **kwargs):
+        def generate_with_callback(callback=None, *args, **kwargs):
             kwargs['stopping_criteria'].append(Stream(callback_func=callback))
             clear_torch_cache()
             with torch.no_grad():
                 shared.model.generate(**kwargs)
 
         def generate_with_streaming(**kwargs):
-            return Iteratorize(generate_with_callback, kwargs, callback=None)
+            return Iteratorize(generate_with_callback, [], kwargs, callback=None)
 
         with generate_with_streaming(**generate_params) as generator:
             for output in generator: