2022-12-21 17:27:31 +01:00
import re
2023-01-22 04:02:46 +01:00
import gc
2023-01-06 05:33:21 +01:00
import time
import glob
2022-12-21 17:27:31 +01:00
import torch
2023-01-06 23:56:44 +01:00
import argparse
2023-01-15 19:23:41 +01:00
import json
2023-01-22 04:02:46 +01:00
from sys import exit
2023-01-07 20:33:43 +01:00
from pathlib import Path
2022-12-21 17:27:31 +01:00
import gradio as gr
2023-01-15 04:39:51 +01:00
import warnings
2023-01-19 16:20:57 +01:00
from tqdm import tqdm
2023-01-22 04:02:46 +01:00
import transformers
from transformers import AutoTokenizer , AutoModelForCausalLM
from modules . html_generator import *
from modules . ui import *
2023-01-25 14:17:55 +01:00
from modules . stopping_criteria import _SentinelTokenStoppingCriteria
2022-12-21 17:27:31 +01:00
2023-01-15 19:23:41 +01:00
# Keep transformers quiet except for errors.
transformers.logging.set_verbosity_error()

# Command-line interface: each entry is (flag, keyword arguments for add_argument),
# registered on the parser in the order listed.
_CLI_OPTIONS = (
    ('--model', dict(type=str, help='Name of the model to load by default.')),
    ('--notebook', dict(action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')),
    ('--chat', dict(action='store_true', help='Launch the web UI in chat mode.')),
    ('--cai-chat', dict(action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')),
    ('--cpu', dict(action='store_true', help='Use the CPU to generate text.')),
    ('--load-in-8bit', dict(action='store_true', help='Load the model with 8-bit precision.')),
    ('--auto-devices', dict(action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')),
    ('--disk', dict(action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')),
    ('--disk-cache-dir', dict(type=str, help='Directory to save the disk cache to. Defaults to "cache/".')),
    ('--gpu-memory', dict(type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')),
    ('--cpu-memory', dict(type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')),
    ('--no-stream', dict(action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')),
    ('--settings', dict(type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')),
    ('--listen', dict(action='store_true', help='Make the web UI reachable from your local network.')),
    ('--share', dict(action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()
2023-01-15 04:39:51 +01:00
2023-01-22 20:19:11 +01:00
# Chat modes are slower with streaming enabled; point the user at --no-stream.
if not args.no_stream and (args.chat or args.cai_chat):
    print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")
2023-01-22 20:19:11 +01:00
2023-01-15 19:23:41 +01:00
# Default interface settings. The *_pygmalion variants are used for Pygmalion models.
settings = {
    'max_new_tokens': 200,
    'max_new_tokens_min': 1,
    'max_new_tokens_max': 2000,
    'preset': 'NovelAI-Sphinx Moth',
    'name1': 'Person 1',
    'name2': 'Person 2',
    'context': 'This is a conversation between two people.',
    'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
    'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
    'stop_at_newline': True,
    'history_size': 0,
    'history_size_min': 0,
    'history_size_max': 64,
    'preset_pygmalion': 'Pygmalion',
    'name1_pygmalion': 'You',
    'name2_pygmalion': 'Kawaii',
    'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
    'stop_at_newline_pygmalion': False,
}

# Overlay the defaults with values from the json file given via --settings.
# Keys that are not already in `settings` are ignored.
if args.settings is not None:
    if Path(args.settings).exists():
        # Explicit UTF-8 so prompts with non-ASCII text load regardless of the platform locale.
        with open(Path(args.settings), 'r', encoding='utf-8') as f:
            new_settings = json.load(f)
        for item in new_settings:
            if item in settings:
                settings[item] = new_settings[item]
    else:
        # Previously a mistyped path was ignored silently; tell the user instead.
        print(f"Warning: the settings file {args.settings} does not exist. Using the default settings.\n")
2023-01-15 04:39:51 +01:00
2022-12-21 17:27:31 +01:00
def load_model(model_name):
    """Load `model_name` and its tokenizer, returning (model, tokenizer).

    Without any device/precision flag (--cpu, --load-in-8bit, --auto-devices, --disk,
    --gpu-memory) a fixed default strategy is used. Otherwise the flags are translated
    into keyword arguments for AutoModelForCausalLM.from_pretrained.
    """
    print(f"Loading {model_name}...")
    t0 = time.time()

    # Default settings
    if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None):
        if Path(f"torch-dumps/{model_name}.pt").exists():
            print("Loading in .pt format...")
            # NOTE(review): torch.load unpickles the file; only load dumps you trust.
            model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
        elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
            # The 13b/20b/30b checkpoints of these families are loaded in 8-bit by default.
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
    # Custom
    else:
        # Build real keyword arguments instead of eval'ing generated source text, so quotes
        # or braces in the model name or cache directory cannot break (or inject into) the call.
        params = {'low_cpu_mem_usage': True}
        if args.cpu:
            params['torch_dtype'] = torch.float32
        else:
            params['device_map'] = 'auto'
            # NOTE: --cpu-memory only takes effect together with --gpu-memory.
            if args.gpu_memory is not None:
                cpu_memory = args.cpu_memory if args.cpu_memory is not None else 99
                params['max_memory'] = {0: f'{args.gpu_memory}GiB', 'cpu': f'{cpu_memory}GiB'}
            if args.disk:
                params['offload_folder'] = args.disk_cache_dir if args.disk_cache_dir is not None else 'cache'
            if args.load_in_8bit:
                params['load_in_8bit'] = True
            else:
                params['torch_dtype'] = torch.float16
        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), **params)

    # Loading the tokenizer: gpt4chan models reuse the GPT-J-6B tokenizer when it is available.
    if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
        tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
    else:
        tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
    # Over-long prompts lose their beginning, keeping the most recent text.
    tokenizer.truncation_side = 'left'

    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer
2023-01-06 06:26:33 +01:00
def fix_gpt4chan(s):
    """Remove empty replies from gpt4chan outputs.

    A post header ("--- <id>") followed only by a ">>" quote line, a blank or
    space-only line, or two empty lines is collapsed into the next post's "---".
    Each pass applies the three rewrites in order; ten passes are made.
    """
    empty_reply_patterns = (
        "--- [0-9]*\n>>[0-9]*\n---",
        "--- [0-9]*\n *\n---",
        "--- [0-9]*\n\n\n---",
    )
    for _ in range(10):
        for pattern in empty_reply_patterns:
            s = re.sub(pattern, "---", s)
    return s
2023-01-16 20:35:45 +01:00
def fix_galactica(s):
    """Fix the LaTeX equations in galactica output.

    Display (\\[ \\]) and inline (\\( \\)) delimiters all become '$', and any '$$'
    produced (or already present) is then collapsed to a single '$'.
    """
    for delimiter in (r'\[', r'\]', r'\(', r'\)'):
        s = s.replace(delimiter, r'$')
    return s.replace(r'$$', r'$')
2023-01-25 14:17:55 +01:00
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
    """Tokenize `prompt` into a (1, n) id tensor, on the GPU unless --cpu.

    The result is truncated to 2048 - tokens_to_generate tokens (2048 is presumably the
    model's context size -- TODO confirm). The tokenizer's truncation side is set to
    'left' in load_model, so the start of the prompt is what gets dropped.
    """
    if not args.cpu:
        torch.cuda.empty_cache()
    input_ids = tokenizer.encode(
        str(prompt),
        return_tensors='pt',
        truncation=True,
        max_length=2048 - tokens_to_generate,
        add_special_tokens=add_special_tokens,
    )
    return input_ids if args.cpu else input_ids.cuda()
2023-01-19 14:43:05 +01:00
def decode(output_ids):
    """Convert token ids back to text, dropping special tokens and any literal '<|endoftext|>'."""
    text = tokenizer.decode(output_ids, skip_special_tokens=True)
    return text.replace(r'<|endoftext|>', '')
def formatted_outputs(reply, model_name):
    """Shape a generated `reply` for the UI outputs.

    In chat modes the text is returned as-is. Otherwise a 3-tuple is returned:
    (text, text for the GALACTICA-only output or a placeholder message, HTML rendering).
    GALACTICA output gets its LaTeX delimiters fixed; gpt4chan output gets empty replies
    removed and is rendered with the 4chan HTML template.
    """
    if args.chat or args.cai_chat:
        return reply

    lowered_name = model_name.lower()
    if lowered_name.startswith('galactica'):
        fixed = fix_galactica(reply)
        return fixed, fixed, generate_basic_html(fixed)

    placeholder = 'Only applicable for GALACTICA models.'
    if lowered_name.startswith(('gpt4chan', 'gpt-4chan', '4chan')):
        fixed = fix_gpt4chan(reply)
        return fixed, placeholder, generate_4chan_html(fixed)
    return reply, placeholder, generate_basic_html(reply)
2023-01-19 14:43:05 +01:00
2023-01-25 14:17:55 +01:00
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
    """Generate a continuation of `question`, yielding formatted_outputs() of the text so far.

    Args:
        question: prompt text.
        tokens: maximum number of new tokens to generate.
        inference_settings: preset name; the file presets/<name>.txt is spliced into the
            model.generate(...) call.
        selected_model: model to use; if it differs from the loaded one it is loaded first.
        eos_token: optional string whose last token id ends generation instead of the
            tokenizer's EOS token.
        stopping_string: optional string handed to _SentinelTokenStoppingCriteria
            (presumably stops once it appears after the prompt -- see that class).

    Yields:
        formatted_outputs() of the decoded output (prompt included). With --no-stream this
        happens once; otherwise the prompt itself is yielded first, then the text after
        every chunk of up to 8 new tokens.
    """
    global model, tokenizer, model_name, loaded_preset, preset

    # Switch models on demand, releasing the old weights before loading the new ones.
    if selected_model != model_name:
        model_name = selected_model
        model = tokenizer = None
        if not args.cpu:
            gc.collect()
            torch.cuda.empty_cache()
        model, tokenizer = load_model(model_name)
    # The preset text is cached and only re-read when a different preset is selected.
    if inference_settings != loaded_preset:
        with open(Path(f'presets/{inference_settings}.txt'), 'r') as infile:
            preset = infile.read()
        loaded_preset = inference_settings

    cuda = "" if args.cpu else ".cuda()"
    # Token id that ends generation, as a plain int (the last token of `eos_token` if given).
    if eos_token is None:
        n = tokenizer.eos_token_id
    else:
        n = tokenizer.encode(eos_token)[-1]
    input_ids = encode(question, tokens)

    # The stopping_criteria code below was copied from
    # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
    if stopping_string is not None:
        t = encode(stopping_string, 0, add_special_tokens=False)
        stopping_criteria_list = transformers.StoppingCriteriaList([
            _SentinelTokenStoppingCriteria(
                sentinel_token_ids=t,
                starting_idx=len(input_ids[0])
            )
        ])
    else:
        stopping_criteria_list = None

    # NOTE(review): the preset file's contents are spliced into Python source and eval'ed,
    # so files in presets/ must be trusted. Presets refer to locals such as `tokens` by name,
    # and the expressions below refer to `n`, `input_ids` and `stopping_criteria_list` by name.
    if args.no_stream:
        # Generate the entire reply at once
        t0 = time.time()
        output = eval(f"model.generate(input_ids, eos_token_id=n, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
        reply = decode(output[0])
        t1 = time.time()
        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
        yield formatted_outputs(reply, model_name)
    else:
        # Generate the reply in chunks of up to 8 tokens, yielding after each chunk
        yield formatted_outputs(question, model_name)
        # Work on a local copy so the cached global `preset` keeps matching the file on disk.
        # `chunk` is looked up by name when the expression below is eval'ed.
        chunk_preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=chunk')
        for start in tqdm(range(0, tokens, 8)):
            # The final chunk is shortened so no more than `tokens` new tokens are produced.
            chunk = min(8, tokens - start)  # noqa: F841 -- used inside eval
            output = eval(f"model.generate(input_ids, eos_token_id=n, stopping_criteria=stopping_criteria_list, {chunk_preset}){cuda}")
            reply = decode(output[0])
            yield formatted_outputs(reply, model_name)
            input_ids = output
            if output[0][-1] == n:
                break
2023-01-19 01:37:21 +01:00
2023-01-22 04:49:59 +01:00
def get_available_models():
    """Return the model names in models/ and torch-dumps/, sorted case-insensitively.

    A trailing '.pt' is stripped, so a torch dump and a model directory with the same
    name show up once. Names ending in '.txt' are skipped.
    """
    names = set()
    for path in list(Path('models/').glob('*')) + list(Path('torch-dumps/').glob('*')):
        name = str(path.name)
        # Strip only a trailing '.pt'; names that merely contain '.pt' elsewhere
        # (e.g. "my.ptmodel") must stay intact.
        if name.endswith('.pt'):
            name = name[:-len('.pt')]
        if not name.endswith('.txt'):
            names.add(name)
    return sorted(names, key=str.lower)
def get_available_presets():
    """Return the names of presets/*.txt (file name minus its last extension), sorted case-insensitively."""
    preset_names = {path.name.rsplit('.', 1)[0] for path in Path('presets').glob('*.txt')}
    return sorted(preset_names, key=str.lower)
def get_available_characters():
    """Return "None" followed by the names of characters/*.json, sorted case-insensitively."""
    character_names = {path.name.rsplit('.', 1)[0] for path in Path('characters').glob('*.json')}
    return ["None"] + sorted(character_names, key=str.lower)
available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()

# Choosing the default model
if args.model is not None:
    model_name = args.model
else:
    if len(available_models) == 0:
        print("No models are available! Please download at least one.")
        exit(0)
    elif len(available_models) == 1:
        i = 0
    else:
        print("The following models are available:\n")
        # Loop variable is `name`, not `model`, so the menu doesn't rebind the module-level `model`.
        for i, name in enumerate(available_models):
            print(f"{i+1}. {name}")
        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
        # Re-ask until the answer is a number in range. A bare int(input())-1 would crash
        # on non-numeric input and silently accept 0 (index -1, i.e. the last model).
        while True:
            try:
                i = int(input()) - 1
            except ValueError:
                i = -1
            if 0 <= i < len(available_models):
                break
            print(f"Please enter a number between 1 and {len(available_models)}.")
        print()
    model_name = available_models[i]
model, tokenizer = load_model(model_name)
# No preset has been read yet; generate_reply() loads one on first use.
loaded_preset = None
2023-01-06 23:56:44 +01:00
2023-01-09 00:10:31 +01:00
# UI settings
# gpt4chan models start from the dedicated 'prompt_gpt4chan' default prompt.
if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
    default_text = settings['prompt_gpt4chan']
else:
    default_text = settings['prompt']
description = "\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
css = (
    ".my-4 {margin-top: 0} "
    ".py-6 {padding-top: 2.5rem} "
    "#refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} "
    "#download-label, #upload-label {min-height: 0}"
)
2023-01-22 04:49:59 +01:00
2023-01-19 02:44:47 +01:00
if args . chat or args . cai_chat :
2023-01-08 02:52:46 +01:00
history = [ ]
2023-01-19 20:46:46 +01:00
character = None
2023-01-08 02:52:46 +01:00
2023-01-15 04:39:51 +01:00
def clean_chat_message(text):
    """Normalize line breaks in a chat message so they render as paragraph breaks.

    Every newline is doubled, runs of three or more newlines collapse to exactly
    two, and surrounding whitespace is stripped.
    """
    doubled = text.replace('\n', '\n\n')
    collapsed = re.sub(r"\n{3,}", "\n\n", doubled)
    return collapsed.strip()
2023-01-20 21:03:09 +01:00
def generate_chat_prompt(text, tokens, name1, name2, context, history_size):
    """Assemble the chat prompt: context, as much recent history as fits, then the new message.

    Args:
        text: the user's new message (normalized with clean_chat_message).
        tokens: tokens reserved for the reply; the prompt is kept under 2048 - tokens.
        name1, name2: speaker labels for the user and the bot.
        context: text placed at the top of the prompt.
        history_size: stop adding history once this many messages are included; 0 = no limit.

    Returns:
        The prompt string, ending with "{name2}:" so the model continues as the bot.
    """
    text = clean_chat_message(text)
    budget = 2048 - tokens
    rows = [f"{context.strip()}\n"]

    # Walk the global history from newest to oldest, inserting each exchange right after
    # the context. encode() truncates to `budget` tokens, so the measured length saturates
    # at `budget` (ending the walk) once the prompt is full.
    included = 0
    for turn in range(len(history) - 1, -1, -1):
        if len(encode(''.join(rows), tokens)[0]) >= budget:
            break
        rows.insert(1, f"{name2}: {history[turn][1].strip()}\n")
        included += 1
        # The greeting row carries only the visibility marker, not a real user message.
        if history[turn][0] != '<|BEGIN-VISIBLE-CHAT|>':
            rows.insert(1, f"{name1}: {history[turn][0].strip()}\n")
            included += 1
        if history_size != 0 and included >= history_size:
            break

    rows += [f"{name1}: {text}\n", f"{name2}:"]

    # If the new message pushed the prompt over budget, drop the oldest two rows at a time,
    # always keeping the context and the two rows just appended.
    while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= budget:
        del rows[1:3]

    return ''.join(rows)
2023-01-08 02:52:46 +01:00
2023-01-21 19:04:13 +01:00
def remove_example_dialogue_from_history(history):
    """Return a deep copy of the visible part of `history`.

    Rows before the first row whose user message contains '<|BEGIN-VISIBLE-CHAT|>'
    (the character's example dialogue) are dropped, and the marker is removed from
    that row's user message. Without a marker the whole history is copied.
    The input list is never modified.
    """
    # Imported here because this module never imports `copy` itself; otherwise the
    # name is only available, if at all, through a wildcard import.
    import copy

    for i, row in enumerate(history):
        if '<|BEGIN-VISIBLE-CHAT|>' in row[0]:
            # Copy only the rows that are kept instead of the whole history.
            visible = copy.deepcopy(history[i:])
            visible[0][0] = visible[0][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
            return visible
    return copy.deepcopy(history)
2023-01-20 21:03:09 +01:00
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
    """Send `text` as name1's message and stream name2's reply into the global history.

    A new [text, reply] row is appended to `history` and overwritten as the reply grows.
    Each yield is the visible history (example dialogue removed).

    Args:
        check: "stop at newline" mode; the reply is cut to its first line and a newline
            is used as the end-of-sequence token.
        Remaining arguments are forwarded to generate_chat_prompt / generate_reply.
    """
    question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
    history.append(['', ''])
    eos_token = '\n' if check else None

    # Names are escaped so characters such as "C++" or "(Bot)" don't break the regex.
    name2_pattern = re.compile(f"(^|\n){re.escape(name2)}:")
    # The prompt does not change while streaming, so count its bot turns only once.
    # The reply starts after the last of those turns.
    prompt_turns = len(list(name2_pattern.finditer(question)))
    next_speaker_tag = f"\n{name1}:"

    for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=next_speaker_tag):
        next_character_found = False

        starts = [m.start() for m in name2_pattern.finditer(reply)]
        idx = starts[prompt_turns - 1]
        reply = reply[idx + len(f"\n{name2}:"):]

        if check:
            reply = reply.split('\n')[0].strip()
        else:
            # Cut the reply where the model starts writing the user's next line.
            idx = reply.find(next_speaker_tag)
            if idx != -1:
                reply = reply[:idx]
                next_character_found = True
        reply = clean_chat_message(reply)

        history[-1] = [text, reply]
        if next_character_found:
            break

        # Prevent the chat log from flashing if something like "\nYo" is generated just
        # before "\nYou:" is completed
        next_character_substring_found = any(reply.endswith(next_speaker_tag[:j]) for j in range(1, len(next_speaker_tag)))
        if not next_character_substring_found:
            yield remove_example_dialogue_from_history(history)

    yield remove_example_dialogue_from_history(history)
2023-01-19 01:51:18 +01:00
2023-01-20 21:03:09 +01:00
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
    """Same as chatbot_wrapper, but each intermediate history is rendered as chat HTML."""
    updates = chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size)
    for visible_history in updates:
        yield generate_chat_html(visible_history, name1, name2, character)
2023-01-15 16:20:04 +01:00
2023-01-22 06:19:58 +01:00
def regenerate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
    """Drop the newest exchange and generate a new reply to the same user message.

    `text` is ignored; the removed row's user message is re-sent. Yields the same
    updates as cai_chatbot_wrapper (--cai-chat) or chatbot_wrapper.
    """
    # Nothing to regenerate: with an empty history pop() would raise. For the greeting row,
    # re-sending its marker text as a user message would put it into the prompt.
    # In both cases just redraw the current log.
    if len(history) == 0 or '<|BEGIN-VISIBLE-CHAT|>' in history[-1][0]:
        visible = remove_example_dialogue_from_history(history)
        yield generate_chat_html(visible, name1, name2, character) if args.cai_chat else visible
        return

    last = history.pop()
    text = last[0]
    wrapper = cai_chatbot_wrapper if args.cai_chat else chatbot_wrapper
    for update in wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
        yield update
2023-01-15 16:20:04 +01:00
def remove_last_message(name1, name2):
    """Remove the newest exchange from the history.

    Returns:
        (chat view, removed user message). The view is HTML in --cai-chat mode and the
        visible history list otherwise. The message is '' when nothing was removed.
    """
    # Only real user turns are removable. With an empty history pop() would raise. Popping
    # the '<|BEGIN-VISIBLE-CHAT|>' greeting row would leave no marker, exposing the hidden
    # example dialogue, and would return the marker text as the "removed message".
    if len(history) > 0 and '<|BEGIN-VISIBLE-CHAT|>' not in history[-1][0]:
        last = history.pop()
    else:
        last = ['', '']
    _history = remove_example_dialogue_from_history(history)
    if args.cai_chat:
        return generate_chat_html(_history, name1, name2, character), last[0]
    else:
        return _history, last[0]
2023-01-15 07:19:09 +01:00
2023-01-15 16:20:04 +01:00
def clear_html():
    """Render an empty chat log (no messages, blank speaker names) as HTML."""
    no_messages = []
    return generate_chat_html(no_messages, "", "", character)
2023-01-15 16:20:04 +01:00
2023-01-25 14:52:35 +01:00
def clear_chat_log(_character, name1, name2):
    """Reset the chat history and return the new chat view.

    With a character selected, the character is reloaded (which rebuilds its example
    dialogue and greeting); otherwise the history is emptied. The view is HTML in
    --cai-chat mode and the visible history list otherwise.
    """
    global history
    if _character == 'None':
        history = []
    else:
        load_character(_character, name1, name2)
    visible = remove_example_dialogue_from_history(history)
    return generate_chat_html(visible, name1, name2, character) if args.cai_chat else visible
2023-01-19 18:03:47 +01:00
def redraw_html(name1, name2):
    """Re-render the visible part of the global chat history as HTML."""
    # `history` is only read here, so no `global` declaration is needed.
    return generate_chat_html(remove_example_dialogue_from_history(history), name1, name2, character)
2023-01-19 18:03:47 +01:00
2023-01-23 13:45:10 +01:00
def tokenize_dialogue(dialogue, name1, name2):
    """Parse a "Name: message" transcript into [[user_message, bot_message], ...] rows.

    '<START>' markers are removed and "Anon:"/"anon:" speakers are renamed to "You:".
    A row is emitted at every name2 message, paired with the preceding name1 message
    ('' if there was none); trailing name1 messages without a reply are dropped.
    Text without any speaker tag yields an empty list.
    """
    history = []

    dialogue = re.sub('<START>', '', dialogue)
    dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
    # Names are escaped so characters such as "C++" or "(Bot)" match literally.
    speaker_tag = re.compile(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):")
    idx = [m.start() for m in speaker_tag.finditer(dialogue)]
    if not idx:
        # No speaker tags at all; indexing idx[-1] would raise IndexError.
        return history

    # Each message runs from one speaker tag to the next (or to the end of the text).
    messages = [dialogue[start:end].strip() for start, end in zip(idx, idx[1:] + [len(dialogue)])]

    entry = ['', '']
    for message in messages:
        if message.startswith(f'{name1}:'):
            entry[0] = message[len(f'{name1}:'):].strip()
        elif message.startswith(f'{name2}:'):
            entry[1] = message[len(f'{name2}:'):].strip()
            if not (len(entry[0]) == 0 and len(entry[1]) == 0):
                history.append(entry)
            entry = ['', '']

    return history
2023-01-23 13:45:10 +01:00
def save_history():
    """Write the global chat history to logs/conversation.json and return that path.

    The file holds {"data": history} as indented JSON; logs/ is created if missing.
    """
    log_dir = Path('logs')
    if not log_dir.exists():
        log_dir.mkdir()
    out_path = log_dir / 'conversation.json'
    out_path.write_text(json.dumps({'data': history}, indent=2))
    return out_path
2023-01-25 19:45:25 +01:00
def upload_history(file, name1, name2):
    """Replace the global chat history with an uploaded log.

    Accepted formats:
      * this UI's JSON export: {"data": [[user, bot], ...]}
      * Pygmalion AI's official web UI JSON: {"chat": ["Name: message", ...]}
      * anything that is not JSON: a "Name: message" transcript (see tokenize_dialogue)
    JSON that matches neither layout leaves the history unchanged.

    Args:
        file: raw bytes of the uploaded file, decoded as UTF-8.
        name1, name2: user and bot names, used to parse plain-text transcripts and to
            detect a Pygmalion log that opens with the bot.
    """
    global history
    file = file.decode('utf-8')
    # Catch only the JSON parse failure, so errors in the conversions below surface
    # instead of being misread as a plain-text upload.
    try:
        j = json.loads(file)
    except json.JSONDecodeError:
        history = tokenize_dialogue(file, name1, name2)
        return

    if not isinstance(j, dict):
        return
    if 'data' in j:
        history = j['data']
    # Compatibility with Pygmalion AI's official web UI
    elif 'chat' in j:
        # Strip the "Name:" prefix from every line, then pair consecutive messages.
        messages = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
        if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
            # A log that opens with the bot gets that message as the visible greeting row.
            history = [['<|BEGIN-VISIBLE-CHAT|>', messages[0]]] + [[messages[i], messages[i+1]] for i in range(1, len(messages)-1, 2)]
        else:
            history = [[messages[i], messages[i+1]] for i in range(0, len(messages)-1, 2)]
2023-01-19 20:46:46 +01:00
def load_character(_character, name1, name2):
    """Load a character card, or reset to the default persona when `_character` is 'None'.

    Sets the globals `history` and `character`. For a real character, the context is
    built from the card's persona and scenario fields. Its example dialogue becomes the
    (hidden) start of the history, and its greeting, or "Hello there!", becomes the first
    visible message.

    Returns:
        (name2, context, chat view). The view is HTML in --cai-chat mode and the visible
        history list otherwise.
    """
    global history, character
    context = ""
    history = []
    if _character != 'None':
        character = _character
        # Read as UTF-8 explicitly: the platform default (e.g. cp1252 on Windows) fails on
        # cards containing non-ASCII text.
        with open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8') as f:
            data = json.load(f)
        name2 = data['char_name']
        if 'char_persona' in data and data['char_persona'] != '':
            context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
        if 'world_scenario' in data and data['world_scenario'] != '':
            context += f"Scenario: {data['world_scenario']}\n"
        context = f"{context.strip()}\n<START>\n"
        if 'example_dialogue' in data and data['example_dialogue'] != '':
            history = tokenize_dialogue(data['example_dialogue'], name1, name2)
        # The marker row separates the example dialogue from the visible chat.
        if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
            history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
        else:
            history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
    else:
        character = None
        context = settings['context_pygmalion']
        name2 = settings['name2_pygmalion']

    _history = remove_example_dialogue_from_history(history)
    if args.cai_chat:
        return name2, context, generate_chat_html(_history, name1, name2, character)
    else:
        return name2, context, _history
2023-01-19 20:46:46 +01:00
2023-01-25 19:45:25 +01:00
def upload_character(file, name1, name2):
    """Save an uploaded character card into ``characters/`` and return its name.

    The upload is parsed as JSON (so invalid files are rejected before anything
    is written) and stored verbatim as UTF-8. The file is named after the
    card's ``char_name``. If that name is taken, a ``_001``, ``_002``, …
    suffix is appended until the name is free.

    Args:
        file: Raw bytes of the uploaded file (Gradio ``File(type='binary')``).
        name1: Unused; accepted because the Gradio event passes it.
        name2: Unused; accepted because the Gradio event passes it.

    Returns:
        The stored file name without the ``.json`` extension, which the UI
        uses as the new value of the character dropdown.

    Raises:
        UnicodeDecodeError, json.JSONDecodeError: if the upload is not UTF-8 JSON.
        KeyError: if the card has no ``char_name`` field.
    """
    text = file.decode('utf-8')
    data = json.loads(text)

    # `char_name` comes from an untrusted upload and becomes part of a path:
    # replace path separators, control characters and characters that are
    # illegal in Windows file names so the file cannot escape characters/.
    base_name = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '_', str(data['char_name']))
    if not base_name:
        base_name = 'character'

    outfile_name = base_name
    i = 1
    while Path(f'characters/{outfile_name}.json').exists():
        outfile_name = f'{base_name}_{i:03d}'
        i += 1

    # Write with the same encoding the upload was decoded with; the platform
    # default (e.g. cp1252 on Windows) cannot represent every character.
    with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
        f.write(text)
    print(f'New character saved to "characters/{outfile_name}.json".')
    return outfile_name
2023-01-19 20:46:46 +01:00
# Settings keys with a `_pygmalion` suffix are used when the loaded model's name
# contains "pygmalion"; otherwise the plain keys are used.
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
# Extra CSS: force elements with the class `h-[40vh]` to 66.67vh and center the
# page at 800px. The literal is raw so the CSS escapes `\[` / `\]` reach the
# browser without relying on Python's deprecated handling of unknown escapes.
with gr.Blocks(css=css + r".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface:
    if args.cai_chat:
        # Character.AI-style conversation rendered as HTML.
        display1 = gr.HTML(value=generate_chat_html([], "", "", character))
    else:
        display1 = gr.Chatbot()
    textbox = gr.Textbox(label='Input')
    btn = gr.Button("Generate")

    with gr.Row():
        stop = gr.Button("Stop")
        btn_regenerate = gr.Button("Regenerate")
        btn_remove_last = gr.Button("Remove last")
        btn_clear = gr.Button("Clear history")

    with gr.Row():
        with gr.Column():
            length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
            with gr.Row():
                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
        with gr.Column():
            history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size in prompt (0 for no limit)', value=settings['history_size'])
            with gr.Row():
                preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Generation parameters preset')
                create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")

    name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
    name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
    context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
    with gr.Row():
        character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
        create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
    with gr.Row():
        check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
    with gr.Row():
        with gr.Tab('Download chat history'):
            download = gr.File()
            save_btn = gr.Button(value="Click me")
        with gr.Tab('Upload chat history'):
            upload = gr.File(type='binary')
        with gr.Tab('Upload character'):
            upload_char = gr.File(type='binary')

    # Gradio passes these positionally; keep the order in sync with the
    # parameters of the chat wrapper functions.
    input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
    if args.cai_chat:
        gen_event = btn.click(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen")
        gen_event2 = textbox.submit(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream)
    else:
        gen_event = btn.click(chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen")
        gen_event2 = textbox.submit(chatbot_wrapper, input_params, display1, show_progress=args.no_stream)
    gen_event3 = btn_regenerate.click(regenerate_wrapper, input_params, display1, show_progress=args.no_stream)

    btn_clear.click(clear_chat_log, [character_menu, name1, name2], display1)
    btn_remove_last.click(remove_last_message, [name1, name2], [display1, textbox], show_progress=False)
    # Empty the input box whenever a message is sent or regenerated.
    btn.click(lambda x: "", textbox, textbox, show_progress=False)
    btn_regenerate.click(lambda x: "", textbox, textbox, show_progress=False)
    textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
    # "Stop" cancels any in-flight generation started by the three events above.
    stop.click(None, None, None, cancels=[gen_event, gen_event2, gen_event3])

    save_btn.click(save_history, inputs=[], outputs=[download])
    character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1])
    upload.upload(upload_history, [upload, name1, name2], [])
    upload_char.upload(upload_character, [upload_char, name1, name2], [character_menu])

    # NOTE(review): this display refresh is a separate listener on the same
    # upload event as `upload_history` above; nothing here guarantees it runs
    # after the history has been replaced — confirm the ordering.
    if args.cai_chat:
        upload.upload(redraw_html, [name1, name2], [display1])
    else:
        upload.upload(lambda : remove_example_dialogue_from_history(history), [], [display1])
2023-01-19 02:44:47 +01:00
elif args.notebook:
    # Notebook mode: one text box is both the prompt and the output, because
    # generate_reply writes its result back into `textbox`.
    with gr.Blocks(css=css, analytics_enabled=False) as interface:
        gr.Markdown(description)
        # The generated text is shown three ways: raw, rendered Markdown, and HTML.
        with gr.Tab('Raw'):
            textbox = gr.Textbox(value=default_text, lines=23)
        with gr.Tab('Markdown'):
            markdown = gr.Markdown()
        with gr.Tab('HTML'):
            html = gr.HTML()
        btn = gr.Button("Generate")
        stop = gr.Button("Stop")

        length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                    create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
            with gr.Column():
                with gr.Row():
                    preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Generation parameters preset')
                    create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")

        # Clicking "Generate" (also exposed under the API name "textgen") or
        # submitting the box both run generate_reply on the current text.
        gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
        gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
        # "Stop" cancels either in-flight generation.
        stop.click(None, None, None, cancels=[gen_event, gen_event2])
else:
    # Default mode: separate input (left column) and output (right column) boxes.
    with gr.Blocks(css=css, analytics_enabled=False) as interface:
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                textbox = gr.Textbox(value=default_text, lines=15, label='Input')
                length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
                with gr.Row():
                    preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Generation parameters preset')
                    create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
                with gr.Row():
                    model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                    create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
                btn = gr.Button("Generate")
                with gr.Row():
                    with gr.Column():
                        cont = gr.Button("Continue")
                    with gr.Column():
                        stop = gr.Button("Stop")
            with gr.Column():
                # The generated text is shown three ways: raw, rendered Markdown, and HTML.
                with gr.Tab('Raw'):
                    output_textbox = gr.Textbox(lines=15, label='Output')
                with gr.Tab('Markdown'):
                    markdown = gr.Markdown()
                with gr.Tab('HTML'):
                    html = gr.HTML()

        gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
        gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream)
        # "Continue" feeds the current output back in as the prompt, extending it.
        cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream)
        # "Stop" cancels any of the three in-flight generations.
        stop.click(None, None, None, cancels=[gen_event, gen_event2, cont_event])
2022-12-21 17:27:31 +01:00
2023-01-25 20:10:35 +01:00
# Enable Gradio's request queue, then start the web server. When `args.listen`
# is set, bind to 0.0.0.0 (all IPv4 interfaces) instead of Gradio's default
# host; `args.share` is forwarded unchanged in both cases.
interface.queue()
interface.launch(share=args.share, **({'server_name': "0.0.0.0"} if args.listen else {}))