From 7f06d551a31d91053eb220c95e0dcea9c5ec91c5 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 16 Jun 2023 21:44:56 -0300
Subject: [PATCH] Fix streaming callback

---
 modules/text_generation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/text_generation.py b/modules/text_generation.py
index 7535d141..0d2f55c2 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -256,14 +256,14 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
     # This is based on the trick of using 'stopping_criteria' to create an iterator.
     else:
 
-        def generate_with_callback(callback=None, **kwargs):
+        def generate_with_callback(callback=None, *args, **kwargs):
             kwargs['stopping_criteria'].append(Stream(callback_func=callback))
             clear_torch_cache()
             with torch.no_grad():
                 shared.model.generate(**kwargs)
 
         def generate_with_streaming(**kwargs):
-            return Iteratorize(generate_with_callback, kwargs, callback=None)
+            return Iteratorize(generate_with_callback, [], kwargs, callback=None)
 
         with generate_with_streaming(**generate_params) as generator:
             for output in generator:
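
Judging from the patched call site, Iteratorize takes a positional-args list and a kwargs dict as separate parameters, so the old call Iteratorize(generate_with_callback, kwargs, callback=None) put the generation parameters into the args slot: unpacked with *args, a dict yields only its keys, so shared.model.generate never received the actual parameters. Passing [] for args and the kwargs dict in its proper slot restores the intended calling convention, and adding *args to generate_with_callback lets it tolerate whatever positional arguments Iteratorize forwards. The sketch below shows how such a wrapper typically works; it is a minimal illustrative reconstruction, assuming a signature of Iteratorize(func, args, kwargs, callback=None), not the project's actual implementation.

# Minimal sketch of an Iteratorize-style wrapper (illustrative reconstruction;
# the project's real class lives elsewhere in the repo and may differ).
import traceback
from queue import Queue
from threading import Thread


class Iteratorize:
    """Run a callback-reporting function on a background thread and expose
    the values it reports as an ordinary iterator."""

    def __init__(self, func, args=None, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()      # unique marker for "generation finished"
        self.args = args or []
        self.kwargs = kwargs or {}
        self.stop_now = False

        def _callback(val):
            # Abort the producer if the consumer has already left the loop.
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            try:
                # The wrapped function gets the queue-feeding callback first,
                # then the positional args: this is why the patched call site
                # must pass [] explicitly when there are none.
                self.mfunc(_callback, *self.args, **self.kwargs)
            except ValueError:
                pass                  # consumer requested a stop
            except Exception:
                traceback.print_exc()
            self.q.put(self.sentinel)  # always signal completion
            if self.c_callback:
                self.c_callback()

        Thread(target=gentask).start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)  # block until the producer yields
        if obj is self.sentinel:
            raise StopIteration
        return obj

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True          # stop the producer on its next callback


if __name__ == '__main__':
    # Hypothetical demo producer standing in for generate_with_callback.
    def produce(callback=None, *args, **kwargs):
        for tok in ['streamed', ' ', 'tokens']:
            callback(tok)

    with Iteratorize(produce, [], {}, callback=None) as generator:
        for output in generator:
            print(output, end='')
    print()

A blocking queue plus a sentinel object is the standard way to turn push-style callbacks into pull-style iteration, and the context-manager protocol ensures the background thread is told to stop even if the consumer exits the loop early.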