Remove annoying warnings

This commit is contained in:
oobabooga 2023-01-15 00:39:51 -03:00
parent d962e69496
commit fd220f827f

View File

@ -9,6 +9,7 @@ import gradio as gr
import transformers import transformers
from html_generator import * from html_generator import *
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -20,12 +21,15 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--no-listen', action='store_true', help='Make the webui unreachable from your local network.') parser.add_argument('--no-listen', action='store_true', help='Make the webui unreachable from your local network.')
args = parser.parse_args() args = parser.parse_args()
loaded_preset = None loaded_preset = None
available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))))) available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
available_models = [item for item in available_models if not item.endswith('.txt')] available_models = [item for item in available_models if not item.endswith('.txt')]
available_models = sorted(available_models, key=str.lower) available_models = sorted(available_models, key=str.lower)
available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], list(Path('presets').glob('*.txt'))))) available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], list(Path('presets').glob('*.txt')))))
transformers.logging.set_verbosity_error()
def load_model(model_name): def load_model(model_name):
print(f"Loading {model_name}...") print(f"Loading {model_name}...")
t0 = time.time() t0 = time.time()
@ -188,10 +192,15 @@ if args.notebook:
elif args.chat: elif args.chat:
history = [] history = []
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check): # This gets the new line characters right.
def chat_response_cleaner(text):
text = text.replace('\n', '\n\n') text = text.replace('\n', '\n\n')
text = re.sub(r"\n{3,}", "\n\n", text) text = re.sub(r"\n{3,}", "\n\n", text)
text = text.strip() text = text.strip()
return text
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
text = chat_response_cleaner(text)
question = context+'\n\n' question = context+'\n\n'
for i in range(len(history)): for i in range(len(history)):
@ -209,9 +218,7 @@ elif args.chat:
idx = reply.find(f"\n{name1}:") idx = reply.find(f"\n{name1}:")
if idx != -1: if idx != -1:
reply = reply[:idx] reply = reply[:idx]
reply = reply.replace('\n', '\n\n') reply = chat_response_cleaner(response)
reply = re.sub(r"\n{3,}", "\n\n", reply)
reply = reply.strip()
history.append((text, reply)) history.append((text, reply))
return history return history