2022-12-21 17:27:31 +01:00
import re
2023-01-06 05:33:21 +01:00
import time
import glob
2023-01-06 23:56:44 +01:00
from sys import exit
2022-12-21 17:27:31 +01:00
import torch
2023-01-06 23:56:44 +01:00
import argparse
2023-01-07 20:33:43 +01:00
from pathlib import Path
2022-12-21 17:27:31 +01:00
import gradio as gr
import transformers
2023-01-07 03:14:08 +01:00
from html_generator import *
2023-01-11 03:39:50 +01:00
from transformers import AutoTokenizer , AutoModelForCausalLM
2023-01-15 04:39:51 +01:00
import warnings
2022-12-21 17:27:31 +01:00
2023-01-07 03:14:08 +01:00
2023-01-06 23:56:44 +01:00
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, help='Name of the model to load by default.')

# All remaining options are plain on/off switches, registered in this order.
_SWITCHES = (
    ('--notebook', 'Launch the webui in notebook mode, where the output is written to the same text box as the input.'),
    ('--chat', 'Launch the webui in chat mode.'),
    ('--cpu', 'Use the CPU to generate text.'),
    ('--auto-devices', 'Automatically split the model across the available GPU(s) and CPU.'),
    ('--load-in-8bit', 'Load the model with 8-bit precision.'),
    ('--no-listen', 'Make the webui unreachable from your local network.'),
)
for _flag, _help_text in _SWITCHES:
    parser.add_argument(_flag, action='store_true', help=_help_text)

args = parser.parse_args()
2023-01-15 04:39:51 +01:00
2023-01-06 06:06:59 +01:00
# Name of the preset whose text is currently cached in `preset` (None until the first generation).
loaded_preset = None


def strip_pt_suffix(name):
    """Return *name* without a trailing ``.pt`` extension.

    Only a suffix is removed: ``"opt-1.3b.pt"`` -> ``"opt-1.3b"``, while a
    ``.pt`` that occurs elsewhere (``"my.ptx-model"``) is left untouched.
    """
    return name[:-len('.pt')] if name.endswith('.pt') else name


def preset_name(path):
    """Return the preset name for a ``presets/*.txt`` file, i.e. the file stem.

    Using the stem keeps dots inside the name, so ``"v1.5 tuned.txt"`` maps
    back to ``presets/v1.5 tuned.txt`` when the preset is later opened.
    """
    return Path(path).stem


# Entries of models/ plus serialized dumps (<name>.pt) in torch-dumps/, deduplicated
# by name and sorted case-insensitively (ties broken by the exact name); .txt files are skipped.
available_models = sorted(
    {strip_pt_suffix(p.name) for p in [*Path('models/').glob('*'), *Path('torch-dumps/').glob('*')]},
    key=lambda name: (name.lower(), name),
)
available_models = [name for name in available_models if not name.endswith('.txt')]

available_presets = sorted({preset_name(p) for p in Path('presets').glob('*.txt')})
2023-01-06 05:33:21 +01:00
2023-01-15 04:39:51 +01:00
# Only errors from the transformers library are printed; its warnings/info messages are silenced.
transformers . logging . set_verbosity_error ( )
2022-12-21 17:27:31 +01:00
def load_model(model_name):
    """Load *model_name* and its tokenizer according to the command-line flags.

    Without --cpu/--auto-devices/--load-in-8bit, a torch dump in
    torch-dumps/<name>.pt is preferred; otherwise the model is loaded from
    models/<name> (large gpt-neo/opt/galactica checkpoints in 8-bit, the rest
    in float16 on the GPU). With any of those flags, from_pretrained() is
    called with settings derived from them.

    Returns:
        A ``(model, tokenizer)`` tuple. The elapsed load time is printed.
    """
    print(f"Loading {model_name}...")
    t0 = time.time()

    model_path = Path(f"models/{model_name}")

    # Default settings
    if not (args.cpu or args.auto_devices or args.load_in_8bit):
        dump_path = Path(f"torch-dumps/{model_name}.pt")
        if dump_path.exists():
            print("Loading in .pt format...")
            # NOTE(review): torch.load unpickles the file, which can run arbitrary code;
            # only load dumps from trusted sources.
            model = torch.load(dump_path)
        elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
            # 13B/20B/30B checkpoints of these families are loaded in 8-bit with device_map='auto'.
            model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto', load_in_8bit=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
    # Custom
    else:
        # The keyword arguments are passed directly. Previously they were assembled into a
        # source string and eval()'d with model_name interpolated inside quotes, so a
        # model name containing a quote (reachable from the web UI/API) could execute code.
        settings = {'low_cpu_mem_usage': True}
        move_to_gpu = False
        if args.cpu:
            settings['torch_dtype'] = torch.float32
        elif args.load_in_8bit:
            settings['device_map'] = 'auto'
            settings['load_in_8bit'] = True
        else:
            settings['torch_dtype'] = torch.float16
            if args.auto_devices:
                settings['device_map'] = 'auto'
            else:
                move_to_gpu = True
        model = AutoModelForCausalLM.from_pretrained(model_path, **settings)
        if move_to_gpu:
            model = model.cuda()

    # Loading the tokenizer: gpt4chan variants reuse the GPT-J-6B tokenizer when a local copy exists.
    if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
        tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
    else:
        tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))

    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer
2023-01-06 06:26:33 +01:00
def fix_gpt4chan(s):
    """Remove empty replies from gpt4chan output.

    A post header ("--- <digits>") whose body is only a quote link
    (">><digits>"), only spaces, or two blank lines is collapsed into the
    following "---" separator. The three substitutions run, in that order,
    for ten rounds so that runs of consecutive empty posts are removed too.
    """
    empty_reply_patterns = (
        "--- [0-9]*\n>>[0-9]*\n---",
        "--- [0-9]*\n *\n---",
        "--- [0-9]*\n\n\n---",
    )
    for _ in range(10):
        for pattern in empty_reply_patterns:
            s = re.sub(pattern, "---", s)
    return s
2023-01-11 05:10:11 +01:00
def fix_galactica(s):
    """Rewrite GALACTICA's LaTeX math delimiters as single dollar signs.

    \\[ \\] \\( \\) are each replaced by "$", then any "$$" produced (or
    already present) is collapsed to "$". The order of replacements matters
    and is kept fixed.
    """
    for delimiter in (r'\[', r'\]', r'\(', r'\)', r'$$'):
        s = s.replace(delimiter, r'$')
    return s
2023-01-11 05:10:11 +01:00
def generate_html(s):
    """Render plain text as HTML: one styled <p> per line inside a centred, padded <div>.

    Each line is HTML-escaped, so model output containing ``<``, ``>`` or
    ``&`` is displayed literally instead of being interpreted as markup.
    """
    # Imported locally: the UI code rebinds the module-level name `html` to a gr.HTML component.
    from html import escape

    paragraphs = '\n'.join([f'<p style="margin-bottom: 20px">{escape(line)}</p>' for line in s.split('\n')])
    return f'<div style="max-width: 600px; margin-left: auto; margin-right: auto; background-color:#eef2ff; color:#0b0f19; padding:3em; font-size:1.2em;">{paragraphs}</div>'
2023-01-13 18:28:53 +01:00
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
    """Generate a continuation of *question* with the selected model and preset.

    Args:
        question: prompt; converted with str() before encoding.
        tokens: exposed to the preset text under the local name ``tokens``
            (the preset is spliced into the generate() call below; presumably
            as ``max_new_tokens=tokens`` - verify against presets/*.txt).
        inference_settings: preset name; the text of presets/<name>.txt is
            read and cached in the global ``preset``.
        selected_model: model to use; a different name triggers load_model().
        eos_token: optional string; the id of its last token is passed to
            generate() as ``eos_token_id``.

    Returns:
        ``(reply, markdown_text, html_text)``. ``markdown_text`` is the reply
        itself for galactica models and a placeholder message otherwise.
    """
    global model, tokenizer, model_name, loaded_preset, preset

    if selected_model != model_name:
        model_name = selected_model
        # Drop the references to the old model before loading the new one.
        model = None
        tokenizer = None
        if not args.cpu:
            torch.cuda.empty_cache()
        model, tokenizer = load_model(model_name)
    if inference_settings != loaded_preset:
        # NOTE(review): inference_settings comes from the UI (and its API endpoint) and is
        # interpolated into a path unchecked, and the file's text is eval()'d below.
        # Only expose this server to trusted clients.
        with open(Path(f'presets/{inference_settings}.txt'), 'r') as infile:
            preset = infile.read()
        loaded_preset = inference_settings

    if not args.cpu:
        torch.cuda.empty_cache()
        input_ids = tokenizer.encode(str(question), return_tensors='pt').cuda()
        cuda = ".cuda()"
    else:
        input_ids = tokenizer.encode(str(question), return_tensors='pt')
        cuda = ""

    # The preset text is inserted into the generate() call as source code, so the
    # local names used here (input_ids, tokens, n) must not be renamed.
    if eos_token is None:
        output = eval(f"model.generate(input_ids, {preset}){cuda}")
    else:
        n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")

    reply = tokenizer.decode(output[0], skip_special_tokens=True)
    reply = reply.replace(r'<|endoftext|>', '')

    lowered_name = model_name.lower()
    if lowered_name.startswith('galactica'):
        reply = fix_galactica(reply)
        return reply, reply, generate_html(reply)
    # Same prefixes load_model() treats as the gpt4chan family.
    elif lowered_name.startswith(('gpt4chan', 'gpt-4chan', '4chan')):
        reply = fix_gpt4chan(reply)
        # generate_4chan_html presumably comes from `from html_generator import *` - verify.
        return reply, 'Only applicable for galactica models.', generate_4chan_html(reply)
    else:
        return reply, 'Only applicable for galactica models.', generate_html(reply)
2022-12-21 17:27:31 +01:00
2023-01-06 23:56:44 +01:00
# Choosing the default model: --model wins; otherwise use the only available model,
# or ask the user to pick one from the numbered list.
if args.model is not None:
    model_name = args.model
else:
    if len(available_models) == 0:
        print("No models are available! Please download at least one.")
        exit(0)
    elif len(available_models) == 1:
        i = 0
    else:
        print("The following models are available:\n")
        # `name`, not `model`, so the global model variable is not rebound to a string.
        for i, name in enumerate(available_models):
            print(f"{i+1}. {name}")
        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
        # Re-prompt on non-numeric or out-of-range input; previously "0" silently
        # selected the last model (index -1) and non-numbers crashed with ValueError.
        while True:
            try:
                i = int(input()) - 1
            except ValueError:
                i = -1
            if 0 <= i < len(available_models):
                break
            print(f"Please enter a number between 1 and {len(available_models)}.")
        print()
    model_name = available_models[i]
model, tokenizer = load_model(model_name)
2023-01-06 23:56:44 +01:00
2023-01-09 00:10:31 +01:00
# UI settings: the initial prompt depends on the model family.
# The prefixes match the ones load_model() treats as gpt4chan models.
if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
    default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
else:
    default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:"
description = """

# Text generation lab
Generate text using Large Language Models.
"""
# Spacing tweaks applied to every Gradio layout below.
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}"
2023-01-07 00:22:26 +01:00
2023-01-09 00:10:31 +01:00
# Build the Gradio interface for the selected mode: notebook, chat, or the default input/output layout.
if args.notebook:
    # Notebook mode: the generated text replaces the contents of the input box.
    with gr.Blocks(css=css, analytics_enabled=False) as interface:
        gr.Markdown(description)
        with gr.Tab('Raw'):
            textbox = gr.Textbox(value=default_text, lines=23)
        with gr.Tab('Markdown'):
            markdown = gr.Markdown()
        with gr.Tab('HTML'):
            html = gr.HTML()
        btn = gr.Button("Generate")

        length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_new_tokens', value=200)
        with gr.Row():
            with gr.Column():
                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
            with gr.Column():
                preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Settings preset')

        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True, api_name="textgen")
        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True)
elif args.chat:
    # Conversation so far, as (user message, bot reply) pairs; returned to the Chatbot component.
    history = []

    # This gets the new line characters right.
    def chat_response_cleaner(text):
        """Double single newlines, collapse 3+ newlines to two, and strip surrounding whitespace."""
        text = text.replace('\n', '\n\n')
        text = re.sub(r"\n{3,}", "\n\n", text)
        text = text.strip()
        return text

    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
        """Append the user's message and the model's reply to the history and return it.

        The prompt is *context* followed by every previous exchange as
        "<name1>: ..." / "<name2>: ..." lines, ending with "<name2>:". When
        *check* is set, generation stops at a newline and only the first line
        of the reply is kept; otherwise the reply is cut at the next
        "\\n<name1>:" turn.
        """
        text = chat_response_cleaner(text)

        question = context + '\n\n'
        # [3:-5] strips a "<p>" prefix and "</p>\n" suffix from stored entries.
        # NOTE(review): this relies on the Chatbot component rewriting history entries
        # as rendered HTML - confirm against the installed gradio version.
        for user_message, bot_message in history:
            question += f"{name1}: {user_message[3:-5].strip()}\n"
            question += f"{name2}: {bot_message[3:-5].strip()}\n"
        question += f"{name1}: {text}\n"
        question += f"{name2}:"

        if check:
            reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
            reply = reply[len(question):].split('\n')[0].strip()
        else:
            reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
            reply = reply[len(question):]
            idx = reply.find(f"\n{name1}:")
            if idx != -1:
                reply = reply[:idx]
            # Was chat_response_cleaner(response): `response` is undefined, so this path
            # always raised NameError.
            reply = chat_response_cleaner(reply)

        history.append((text, reply))
        return history

    def remove_last_message():
        """Drop the most recent exchange, if any, and return the history."""
        # Guard: list.pop() on an empty history raised IndexError.
        if history:
            history.pop()
        return history

    def clear():
        """Start a new, empty conversation history."""
        global history
        history = []

    # Pygmalion models get their own default context and character names.
    if 'pygmalion' in model_name.lower():
        context_str = "This is a conversation between two people.\n<START>"
        name1_str = "You"
        name2_str = "Kawaii"
    else:
        context_str = "This is a conversation between two people."
        name1_str = "Person 1"
        name2_str = "Person 2"

    # Raw string: same value as before, without the invalid "\[" escape sequences.
    with gr.Blocks(css=css + r".h-\[40vh\] {height: 50vh}", analytics_enabled=False) as interface:
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_new_tokens', value=200)
                with gr.Row():
                    with gr.Column():
                        model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                    with gr.Column():
                        preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Settings preset')
                name1 = gr.Textbox(value=name1_str, lines=1, label='Your name')
                name2 = gr.Textbox(value=name2_str, lines=1, label='Bot\'s name')
                context = gr.Textbox(value=context_str, lines=2, label='Context')
                with gr.Row():
                    check = gr.Checkbox(value=True, label='Stop generating at new line character?')
            with gr.Column():
                display1 = gr.Chatbot()
                textbox = gr.Textbox(lines=2, label='Input')
                btn = gr.Button("Generate")
                with gr.Row():
                    with gr.Column():
                        btn3 = gr.Button("Remove last message")
                    with gr.Column():
                        btn2 = gr.Button("Clear history")

        btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
        textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
        btn3.click(remove_last_message, [], display1, show_progress=False)
        btn2.click(clear)
        # Empty the input box after sending, and the chat display after clearing.
        btn.click(lambda x: "", textbox, textbox, show_progress=False)
        textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
        btn2.click(lambda x: "", display1, display1)
else:
    def continue_wrapper(question, tokens, inference_settings, selected_model):
        """Generate from the current output and copy the result into both the output and input boxes."""
        a, b, c = generate_reply(question, tokens, inference_settings, selected_model)
        return a, a, b, c

    # Default mode: prompt and controls on the left, rendered output tabs on the right.
    with gr.Blocks(css=css, analytics_enabled=False) as interface:
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                textbox = gr.Textbox(value=default_text, lines=15, label='Input')
                length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_new_tokens', value=200)
                preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Settings preset')
                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                btn = gr.Button("Generate")
                cont = gr.Button("Continue")
            with gr.Column():
                with gr.Tab('Raw'):
                    output_textbox = gr.Textbox(lines=15, label='Output')
                with gr.Tab('Markdown'):
                    markdown = gr.Markdown()
                with gr.Tab('HTML'):
                    html = gr.HTML()

        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True, api_name="textgen")
        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=True)
        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
2022-12-21 17:27:31 +01:00
2023-01-11 05:10:11 +01:00
# Start the web server. A public share link is never created; unless --no-listen is
# given, the server binds to 0.0.0.0 so it is reachable from the local network.
launch_options = {'share': False}
if not args.no_listen:
    launch_options['server_name'] = "0.0.0.0"
interface.launch(**launch_options)