From 5f3f3faa96af4808d1173c9cc2d7452bd8d5a26e Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 2 Apr 2023 17:48:00 -0300
Subject: [PATCH] Better handle CUDA out of memory errors in chat mode

---
 modules/chat.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/modules/chat.py b/modules/chat.py
index db79e7db..2e491d97 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -119,6 +119,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
     # Generate
     cumulative_reply = ''
     for i in range(chat_generation_attempts):
+        reply = None
         for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
             reply = cumulative_reply + reply
 
@@ -145,7 +146,8 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
             if next_character_found:
                 break
 
-        cumulative_reply = reply
+        if reply is not None:
+            cumulative_reply = reply
 
     yield shared.history['visible']
 
@@ -162,6 +164,7 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
 
     cumulative_reply = ''
     for i in range(chat_generation_attempts):
+        reply = None
         for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
             reply = cumulative_reply + reply
             reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
@@ -169,7 +172,8 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
             if next_character_found:
                 break
 
-        cumulative_reply = reply
+        if reply is not None:
+            cumulative_reply = reply
 
     yield reply