mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 09:40:20 +01:00
Fix streaming callback
This commit is contained in:
parent
1e400218e9
commit
7f06d551a3
@ -256,14 +256,14 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
|||||||
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def generate_with_callback(callback=None, **kwargs):
|
def generate_with_callback(callback=None, *args, **kwargs):
|
||||||
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
shared.model.generate(**kwargs)
|
shared.model.generate(**kwargs)
|
||||||
|
|
||||||
def generate_with_streaming(**kwargs):
|
def generate_with_streaming(**kwargs):
|
||||||
return Iteratorize(generate_with_callback, kwargs, callback=None)
|
return Iteratorize(generate_with_callback, [], kwargs, callback=None)
|
||||||
|
|
||||||
with generate_with_streaming(**generate_params) as generator:
|
with generate_with_streaming(**generate_params) as generator:
|
||||||
for output in generator:
|
for output in generator:
|
||||||
|
Loading…
Reference in New Issue
Block a user