Better handle CUDA out of memory errors in chat mode

This commit is contained in:
oobabooga 2023-04-02 17:48:00 -03:00
parent b0890a7925
commit 5f3f3faa96

View File

@ -119,6 +119,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
# Generate # Generate
cumulative_reply = '' cumulative_reply = ''
for i in range(chat_generation_attempts): for i in range(chat_generation_attempts):
reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
reply = cumulative_reply + reply reply = cumulative_reply + reply
@ -145,7 +146,8 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
if next_character_found: if next_character_found:
break break
cumulative_reply = reply if reply is not None:
cumulative_reply = reply
yield shared.history['visible'] yield shared.history['visible']
@ -162,6 +164,7 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
cumulative_reply = '' cumulative_reply = ''
for i in range(chat_generation_attempts): for i in range(chat_generation_attempts):
reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
reply = cumulative_reply + reply reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, check) reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
@ -169,7 +172,8 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
if next_character_found: if next_character_found:
break break
cumulative_reply = reply if reply is not None:
cumulative_reply = reply
yield reply yield reply