Fix memory leak in new streaming (second attempt)

This commit is contained in:
oobabooga 2023-03-11 23:14:49 -03:00
parent 92fe947721
commit 37f0166b2d
2 changed files with 4 additions and 2 deletions

View File

@@ -49,7 +49,7 @@ class Iteratorize:
     def __init__(self, func, kwargs={}, callback=None):
         self.mfunc=func
         self.c_callback=callback
-        self.q = Queue(maxsize=1)
+        self.q = Queue()
         self.sentinel = object()
         self.kwargs = kwargs
@@ -73,3 +73,6 @@ class Iteratorize:
             raise StopIteration
         else:
             return obj
+
+    def __del__(self):
+        pass

View File

@@ -187,7 +187,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         yield formatted_outputs(original_question, shared.model_name)
         for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
-            print(print('Used vram in gib:', torch.cuda.memory_allocated() / 1024**3))
             if shared.soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
             reply = decode(output)