Exception handling (#454)

* Update text_generation.py
* Update extensions.py
This commit is contained in:
oobabooga 2023-03-20 13:36:52 -03:00 committed by GitHub
parent a90f507abe
commit 75a7a84ef2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 0 deletions

View File

@@ -1,3 +1,5 @@
import traceback
import gradio as gr import gradio as gr
import extensions import extensions
@@ -17,6 +19,7 @@ def load_extensions():
print('Ok.') print('Ok.')
except: except:
print('Fail.') print('Fail.')
traceback.print_exc()
# This iterator returns the extensions in the order specified in the command-line # This iterator returns the extensions in the order specified in the command-line
def iterator(): def iterator():

View File

@@ -1,6 +1,7 @@
import gc import gc
import re import re
import time import time
import traceback
import numpy as np import numpy as np
import torch import torch
@@ -110,6 +111,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# No need to generate 8 tokens at a time. # No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k): for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
except:
traceback.print_exc()
finally: finally:
t1 = time.time() t1 = time.time()
output = encode(reply)[0] output = encode(reply)[0]
@@ -243,6 +246,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
except:
traceback.print_exc()
finally: finally:
t1 = time.time() t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)") print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")