Add comments

This commit is contained in:
oobabooga 2023-01-06 02:26:33 -03:00
parent ee88d02292
commit deefa2e86a

View File

@ -44,6 +44,7 @@ def load_model(model_name):
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer return model, tokenizer
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s): def fix_gpt4chan(s):
for i in range(10): for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
@ -52,7 +53,7 @@ def fix_gpt4chan(s):
return s return s
def fn(question, temperature, max_length, inference_settings, selected_model): def generate_reply(question, temperature, max_length, inference_settings, selected_model):
global model, tokenizer, model_name, loaded_preset, preset global model, tokenizer, model_name, loaded_preset, preset
if selected_model != model_name: if selected_model != model_name:
@ -70,7 +71,6 @@ def fn(question, temperature, max_length, inference_settings, selected_model):
input_text = question input_text = question
input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda() input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()
output = eval(f"model.generate(input_ids, {preset}).cuda()") output = eval(f"model.generate(input_ids, {preset}).cuda()")
reply = tokenizer.decode(output[0], skip_special_tokens=True) reply = tokenizer.decode(output[0], skip_special_tokens=True)
@ -86,7 +86,7 @@ else:
default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:" default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:"
interface = gr.Interface( interface = gr.Interface(
fn, generate_reply,
inputs=[ inputs=[
gr.Textbox(value=default_text, lines=15), gr.Textbox(value=default_text, lines=15),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
@ -98,7 +98,7 @@ interface = gr.Interface(
gr.Textbox(placeholder="", lines=15), gr.Textbox(placeholder="", lines=15),
], ],
title="Text generation lab", title="Text generation lab",
description=f"Generate text using Large Language Models. Currently working with {model_name}", description=f"Generate text using Large Language Models.",
) )
interface.launch(share=False, server_name="0.0.0.0") interface.launch(share=False, server_name="0.0.0.0")