Load default model with --model flag

oobabooga 2023-01-06 19:56:44 -03:00
parent ec2973f596
commit f54a13929f

@@ -2,23 +2,19 @@ import os
 import re
 import time
 import glob
+from sys import exit
 import torch
+import argparse
 import gradio as gr
 import transformers
 from transformers import AutoTokenizer
 from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel
-#model_name = "bloomz-7b1-p3"
-#model_name = 'gpt-j-6B-float16'
-#model_name = "opt-6.7b"
-#model_name = 'opt-13b'
-model_name = "gpt4chan_model_float16"
-#model_name = 'galactica-6.7b'
-#model_name = 'gpt-neox-20b'
-#model_name = 'flan-t5'
-#model_name = 'OPT-13B-Erebus'
+parser = argparse.ArgumentParser()
+parser.add_argument('--model', type=str, help='Name of the model to load by default')
+args = parser.parse_args()
 loaded_preset = None
+available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]") + glob.glob("torch-dumps/*[!\.][!t][!x][!t]"))))
 
 def load_model(model_name):
     print(f"Loading {model_name}...")
@@ -85,7 +81,24 @@ def generate_reply(question, temperature, max_length, inference_settings, select
     return reply
 
+# Choosing the default model
+if args.model is not None:
+    model_name = args.model
+else:
+    if len(available_models) == 0:
+        print("No models are available! Please download at least one.")
+        exit(0)
+    elif len(available_models) == 1:
+        i = 0
+    else:
+        print("The following models are available:\n")
+        for i,model in enumerate(available_models):
+            print(f"{i+1}. {model}")
+        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
+        i = int(input())-1
+    model_name = available_models[i]
+
 model, tokenizer = load_model(model_name)
 if model_name.startswith('gpt4chan'):
     default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
 else:
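
Taken out of diff form, the selection logic added here works as follows: passing --model skips the menu entirely; with no flag, the script exits if no models are found, auto-selects a lone model, and otherwise prints a numbered menu and reads the choice from stdin. A hypothetical session (the entry-point script's name is not shown in this view, so server.py is an assumption; the model names are taken from the old commented-out list):

$ python server.py
The following models are available:

1. gpt-j-6B-float16
2. gpt4chan_model_float16

Which one do you want to load? 1-2

$ python server.py --model gpt4chan_model_float16
Loading gpt4chan_model_float16...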
@@ -98,7 +111,7 @@ interface = gr.Interface(
         gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
         gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
         gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"),
-        gr.Dropdown(choices=sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*") + glob.glob("torch-dumps/*")))), value=model_name),
+        gr.Dropdown(choices=available_models, value=model_name),
     ],
     outputs=[
         gr.Textbox(placeholder="", lines=15),
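
The last hunk replaces the inline glob/sort expression in the models gr.Dropdown with the available_models list computed once at the top of the file. Besides removing duplication, this keeps the UI consistent with the startup menu: the old inline expression globbed models/* with no filter, so the dropdown could offer names (such as .txt files) that the menu had excluded.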