2023-01-06 05:41:52 +01:00
import os
2022-12-21 17:27:31 +01:00
import re
2023-01-06 05:33:21 +01:00
import time
import glob
2023-01-06 23:56:44 +01:00
from sys import exit
2022-12-21 17:27:31 +01:00
import torch
2023-01-06 23:56:44 +01:00
import argparse
2022-12-21 17:27:31 +01:00
import gradio as gr
import transformers
2023-01-07 03:14:08 +01:00
from html_generator import *
2022-12-21 17:27:31 +01:00
from transformers import AutoTokenizer
from transformers import GPTJForCausalLM , AutoModelForCausalLM , AutoModelForSeq2SeqLM , OPTForCausalLM , T5Tokenizer , T5ForConditionalGeneration , GPTJModel , AutoModel
2023-01-07 03:14:08 +01:00
2023-01-06 23:56:44 +01:00
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.')
args = parser.parse_args()

# Name of the preset whose text is currently cached in the global `preset`
# (see generate_reply); None until the first preset has been read.
loaded_preset = None


def _model_name_from_path(path):
    """Return the model name for an entry found in models/ or torch-dumps/.

    The directory part is dropped with os.path.basename (which also copes with
    the backslash-separated paths glob returns on Windows), and a trailing
    ".pt" -- the extension of torch-dumps files -- is removed. Only a suffix is
    stripped, so a ".pt" elsewhere in the name is left alone.
    """
    name = os.path.basename(path)
    return name[:-len('.pt')] if name.endswith('.pt') else name


# Every loadable model: entries of models/ plus torch-dumps/*.pt files,
# de-duplicated by name and sorted. Stray .txt files in those folders are
# not models and are skipped.
available_models = sorted(
    name
    for name in {_model_name_from_path(path) for path in glob.glob("models/*") + glob.glob("torch-dumps/*")}
    if not name.endswith('.txt')
)
2023-01-06 05:33:21 +01:00
2022-12-21 17:27:31 +01:00
def load_model(model_name):
    """Load *model_name* and its tokenizer, placing the model on the GPU.

    Weights are looked up first as a whole-model dump in
    ``torch-dumps/<model_name>.pt`` and otherwise as the checkpoint directory
    ``models/<model_name>``, whose loader class and options are chosen from
    the model name.

    Args:
        model_name: Directory name under models/ (or file stem under
            torch-dumps/).

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    print(f"Loading {model_name}...")
    t0 = time.time()
    # Name checks are case-insensitive, so "OPT-13B" is treated like "opt-13b".
    lowered = model_name.lower()

    # Loading the model
    if os.path.exists(f"torch-dumps/{model_name}.pt"):
        print("Loading in .pt format...")
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only keep trusted dumps in torch-dumps/.
        model = torch.load(f"torch-dumps/{model_name}.pt").cuda()
    elif (lowered.startswith(('gpt-neo', 'opt-', 'galactica'))
          and any(size in lowered for size in ('13b', '20b', '30b'))):
        # The largest gpt-neo/opt/galactica checkpoints are loaded in 8-bit
        # with device_map='auto'; there is deliberately no .cuda() call here
        # because device placement is left to device_map.
        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True)
    elif model_name in ['flan-t5', 't5-large']:
        model = T5ForConditionalGeneration.from_pretrained(f"models/{model_name}").cuda()
    else:
        # fp16 weights on a single GPU: gpt-j-6B, the smaller
        # gpt-neo/opt/galactica models, and any other name.
        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

    # Loading the tokenizer
    if lowered.startswith('gpt4chan'):
        # gpt4chan models use the tokenizer from models/gpt-j-6B/ (presumably
        # they ship no tokenizer files of their own), so that directory must
        # exist as well.
        tokenizer = AutoTokenizer.from_pretrained("models/gpt-j-6B/")
    elif model_name in ['flan-t5']:
        tokenizer = T5Tokenizer.from_pretrained(f"models/{model_name}/")
    else:
        tokenizer = AutoTokenizer.from_pretrained(f"models/{model_name}/")

    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer
2023-01-06 06:26:33 +01:00
# Removes empty replies from gpt4chan outputs.
# Each pattern matches one post header ("--- <number>") whose body is empty,
# together with the next post's "---", so the pair collapses into that "---".
_GPT4CHAN_EMPTY_REPLIES = (
    re.compile(r"--- [0-9]*\n>>[0-9]*\n---"),  # body is only a ">>NNN" reference
    re.compile(r"--- [0-9]*\n *\n---"),        # body is one blank / spaces-only line
    re.compile(r"--- [0-9]*\n\n\n---"),        # body is two empty lines
)


def fix_gpt4chan(s):
    """Remove empty posts from gpt4chan-formatted text *s* and return the result.

    A single regex pass only removes every other post in a run of consecutive
    empty posts (matches cannot overlap), so passes are repeated until the
    text stops changing. This always terminates because every substitution
    makes the string strictly shorter.
    """
    while True:
        previous = s
        for pattern in _GPT4CHAN_EMPTY_REPLIES:
            s = pattern.sub("---", s)
        if s == previous:
            return s
2023-01-07 05:56:21 +01:00
def fix_galactica(s):
    """Convert Galactica's LaTeX math delimiters in *s* to single dollar signs.

    Each of ``\\[``, ``\\]``, ``\\(`` and ``\\)`` becomes ``$``. Afterwards every
    non-overlapping ``$$`` (for example two delimiters that ended up adjacent)
    is collapsed to a single ``$``.
    """
    for latex_delimiter in (r'\[', r'\]', r'\(', r'\)'):
        s = s.replace(latex_delimiter, '$')
    return s.replace('$$', '$')
2023-01-06 06:26:33 +01:00
def _check_plain_name(name, kind):
    """Raise ValueError unless *name* is a single, non-empty path component.

    Model and preset names arrive from the web UI and are spliced into file
    paths, so anything like ``"../x"`` is rejected before it reaches the disk.
    """
    if not name or name in ('.', '..') or os.path.basename(name) != name:
        raise ValueError(f"Invalid {kind} name: {name!r}")


def generate_reply(question, temperature, max_length, inference_settings, selected_model):
    """Generate text continuing *question* with the selected model and preset.

    Args:
        question: Prompt text.
        temperature: Slider value. It is not referenced directly below; it is
            only reachable by preset text, which is eval'ed in this scope.
        max_length: Slider value, reachable by preset text the same way.
        inference_settings: Preset name. The keyword arguments passed to
            ``model.generate`` are read from ``presets/<name>.txt``.
        selected_model: Model name. The global model/tokenizer pair is
            replaced when it differs from the currently loaded one.

    Returns:
        A 3-tuple for the (raw text, Markdown, HTML) output components. The
        Markdown slot is only filled for galactica models and the HTML slot
        only for gpt4chan models; otherwise a placeholder message is returned.

    Raises:
        ValueError: If a model or preset name that would be loaded is not a
            plain file name.
    """
    global model, tokenizer, model_name, loaded_preset, preset

    if selected_model != model_name:
        _check_plain_name(selected_model, 'model')
        # Drop the old model and tokenizer first so their GPU memory can be
        # freed before the new weights are loaded.
        model = None
        tokenizer = None
        # Mark "nothing loaded": if load_model raises, the next request will
        # retry the load instead of using a None model.
        model_name = None
        torch.cuda.empty_cache()
        model, tokenizer = load_model(selected_model)
        model_name = selected_model

    if inference_settings != loaded_preset:
        _check_plain_name(inference_settings, 'preset')
        with open(f'presets/{inference_settings}.txt', 'r') as infile:
            preset = infile.read()
        loaded_preset = inference_settings

    torch.cuda.empty_cache()
    input_text = question
    input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()

    # The preset text is spliced into the model.generate(...) call and
    # evaluated in this function's scope, so it can refer to locals such as
    # `temperature` and `max_length` (presumably how the sliders take effect).
    # NOTE(review): eval runs arbitrary Python from the preset file -- only
    # trusted files may live in presets/.
    output = eval(f"model.generate(input_ids, {preset}).cuda()")
    reply = tokenizer.decode(output[0], skip_special_tokens=True)

    if model_name.lower().startswith('galactica'):
        reply = fix_galactica(reply)
        return reply, reply, 'Only applicable for gpt4chan.'
    elif model_name.lower().startswith('gpt4chan'):
        reply = fix_gpt4chan(reply)
        return reply, 'Only applicable for galactica models.', generate_html(reply)
    else:
        return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'
2022-12-21 17:27:31 +01:00
2023-01-06 23:56:44 +01:00
# Choosing the default model
def _ask_for_model(choices):
    """List *choices*, then prompt on stdin until a valid 1-based number is entered.

    Returns the chosen entry. Unlike a bare ``int(input()) - 1`` index, this
    never crashes on non-numeric input and never maps "0" or negative numbers
    to entries counted from the end of the list.
    """
    print("The following models are available:\n")
    for number, name in enumerate(choices, start=1):
        print(f"{number}. {name}")
    while True:
        print(f"\nWhich one do you want to load? 1-{len(choices)}\n")
        try:
            number = int(input())
        except ValueError:
            continue
        if 1 <= number <= len(choices):
            return choices[number - 1]


if args.model is not None:
    model_name = args.model
elif len(available_models) == 0:
    print("No models are available! Please download at least one.")
    exit(0)
elif len(available_models) == 1:
    model_name = available_models[0]
else:
    model_name = _ask_for_model(available_models)

model, tokenizer = load_model(model_name)

# Initial prompt shown in the input box. gpt4chan models get a text laid out
# as "---"-separated numbered posts.
if model_name.lower().startswith('gpt4chan'):
    default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
else:
    default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:"
2023-01-07 00:22:26 +01:00
# Preset dropdown choices: file names in presets/ without the ".txt" suffix.
# os.path keeps dots inside names intact (unlike split('.')[0]) and handles
# Windows path separators; sorting makes the order independent of the filesystem.
preset_choices = sorted(os.path.splitext(os.path.basename(path))[0] for path in glob.glob("presets/*.txt"))
default_preset = "NovelAI-Sphinx Moth"
if preset_choices and default_preset not in preset_choices:
    # Preselect a preset that actually exists so the first request does not
    # try to open a missing file.
    default_preset = preset_choices[0]

header_text = """
    # Text generation lab
    Generate text using Large Language Models.
    """

if args.notebook:
    # Notebook mode: the reply is written back into the same text box as the prompt.
    with gr.Blocks() as interface:
        gr.Markdown(header_text)
        with gr.Tab('Raw'):
            textbox = gr.Textbox(value=default_text, lines=23)
        with gr.Tab('Markdown'):
            markdown = gr.Markdown()
        with gr.Tab('HTML'):
            html = gr.HTML()
        btn = gr.Button("Generate")

        with gr.Row():
            with gr.Column():
                temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
                length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
            with gr.Column():
                preset_menu = gr.Dropdown(choices=preset_choices, value=default_preset, label='Preset')
                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')

        btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
else:
    # Default mode: the input and its controls on the left, the outputs on the right.
    with gr.Blocks() as interface:
        gr.Markdown(header_text)
        with gr.Row():
            with gr.Column():
                textbox = gr.Textbox(value=default_text, lines=15, label='Input')
                temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
                length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
                preset_menu = gr.Dropdown(choices=preset_choices, value=default_preset, label='Preset')
                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                btn = gr.Button("Generate")
            with gr.Column():
                with gr.Tab('Raw'):
                    output_textbox = gr.Textbox(value=default_text, lines=15, label='Output')
                with gr.Tab('Markdown'):
                    markdown = gr.Markdown()
                with gr.Tab('HTML'):
                    html = gr.HTML()

        btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)

# NOTE(review): server_name="0.0.0.0" listens on every network interface, so
# anyone who can reach this port can drive generate_reply, which evals preset
# files and loads models. Consider "127.0.0.1" unless remote access is needed.
interface.launch(share=False, server_name="0.0.0.0")