diff --git a/modules/RWKV.py b/modules/RWKV.py index 70deab28..836d31dc 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -50,11 +50,11 @@ class RWKVModel: return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) def generate_with_streaming(self, **kwargs): - iterable = Iteratorize(self.generate, kwargs, callback=None) - reply = kwargs['context'] - for token in iterable: - reply += token - yield reply + with Iteratorize(self.generate, kwargs, callback=None) as generator: + reply = kwargs['context'] + for token in generator: + reply += token + yield reply class RWKVTokenizer: def __init__(self): diff --git a/modules/callbacks.py b/modules/callbacks.py index 05e8fafa..e0d1c988 100644 --- a/modules/callbacks.py +++ b/modules/callbacks.py @@ -1,3 +1,4 @@ +import gc from queue import Queue from threading import Thread @@ -6,7 +7,6 @@ import transformers import modules.shared as shared - # Copied from https://github.com/PygmalionAI/gradio-ui/ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): @@ -52,17 +52,24 @@ class Iteratorize: self.q = Queue() self.sentinel = object() self.kwargs = kwargs + self.stop_now = False def _callback(val): + if self.stop_now: + raise ValueError self.q.put(val) def gentask(): - ret = self.mfunc(callback=_callback, **self.kwargs) + try: + ret = self.mfunc(callback=_callback, **self.kwargs) + except ValueError: + pass self.q.put(self.sentinel) if self.c_callback: self.c_callback(ret) - Thread(target=gentask).start() + self.thread = Thread(target=gentask) + self.thread.start() def __iter__(self): return self @@ -75,4 +82,16 @@ class Iteratorize: return obj def __del__(self): - pass + clear_torch_cache() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop_now = True + clear_torch_cache() + +def clear_torch_cache(): + gc.collect() + if not shared.args.cpu: + torch.cuda.empty_cache() diff --git a/modules/text_generation.py b/modules/text_generation.py index 5d01c8cb..7f5aad5e 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -186,17 +186,18 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi return Iteratorize(generate_with_callback, kwargs, callback=None) yield formatted_outputs(original_question, shared.model_name) - for output in eval(f"generate_with_streaming({', '.join(generate_params)})"): - if shared.soft_prompt: - output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - reply = decode(output) + with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator: + for output in generator: + if shared.soft_prompt: + output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) + reply = decode(output) - if not (shared.args.chat or shared.args.cai_chat): - reply = original_question + apply_extensions(reply[len(question):], "output") - yield formatted_outputs(reply, shared.model_name) + if not (shared.args.chat or shared.args.cai_chat): + reply = original_question + apply_extensions(reply[len(question):], "output") + yield formatted_outputs(reply, shared.model_name) - if output[-1] == n: - break + if output[-1] == n: + break # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' else: