From 75a7a84ef278cf24c5b59071f38c75ea5ab55aa4 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 20 Mar 2023 13:36:52 -0300
Subject: [PATCH] Exception handling (#454)

* Update text_generation.py

* Update extensions.py
---
 modules/extensions.py      | 3 +++
 modules/text_generation.py | 5 +++++
 2 files changed, 8 insertions(+)

diff --git a/modules/extensions.py b/modules/extensions.py
index 836fbc60..dbc93840 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,3 +1,5 @@
+import traceback
+
 import gradio as gr
 
 import extensions
@@ -17,6 +19,7 @@ def load_extensions():
             print('Ok.')
         except:
             print('Fail.')
+            traceback.print_exc()
 
 # This iterator returns the extensions in the order specified in the command-line
 def iterator():
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 9159975c..a70d490c 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -1,6 +1,7 @@
 import gc
 import re
 import time
+import traceback
 
 import numpy as np
 import torch
@@ -110,6 +111,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
                     yield formatted_outputs(reply, shared.model_name)
+        except:
+            traceback.print_exc()
         finally:
             t1 = time.time()
             output = encode(reply)[0]
@@ -243,6 +246,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
                 yield formatted_outputs(reply, shared.model_name)
 
+        except:
+            traceback.print_exc()
         finally:
             t1 = time.time()
             print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
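
Note: the pattern this patch applies is to call traceback.print_exc() inside the pre-existing bare except blocks, so a failure during extension loading or text generation prints a full stack trace instead of being silently swallowed. A minimal standalone sketch of that pattern, using a hypothetical risky_operation() stand-in that is not part of the patch:

import traceback

def risky_operation():
    # Hypothetical stand-in for extension loading or model generation.
    raise RuntimeError("something went wrong")

try:
    risky_operation()
    print('Ok.')
except:
    print('Fail.')
    # The bare except still reports the failure briefly, but the full
    # stack trace is now printed to stderr as well, rather than lost.
    traceback.print_exc()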