2022-12-21 17:27:31 +01:00
import re
2023-01-22 04:02:46 +01:00
import gc
2023-01-06 05:33:21 +01:00
import time
import glob
2022-12-21 17:27:31 +01:00
import torch
2023-01-06 23:56:44 +01:00
import argparse
2023-01-15 19:23:41 +01:00
import json
2023-01-28 23:16:37 +01:00
import io
2023-01-29 00:18:23 +01:00
import base64
2023-01-27 04:40:39 +01:00
import sys
2023-02-01 13:57:27 +01:00
import os
2023-02-03 13:02:35 +01:00
from datetime import datetime
2023-01-07 20:33:43 +01:00
from pathlib import Path
2023-01-28 23:16:37 +01:00
from PIL import Image
2023-01-27 16:01:11 +01:00
import copy
2022-12-21 17:27:31 +01:00
import gradio as gr
2023-01-15 04:39:51 +01:00
import warnings
2023-01-19 16:20:57 +01:00
from tqdm import tqdm
2023-01-22 04:02:46 +01:00
import transformers
2023-02-01 13:57:27 +01:00
from transformers import AutoTokenizer , AutoModelForCausalLM , AutoConfig
2023-01-22 04:02:46 +01:00
from modules . html_generator import *
from modules . ui import *
2023-01-25 14:17:55 +01:00
from modules . stopping_criteria import _SentinelTokenStoppingCriteria
2022-12-21 17:27:31 +01:00
2023-01-15 19:23:41 +01:00
transformers.logging.set_verbosity_error()

# Command-line interface. Every option is read later through the module-level `args`.
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
parser.add_argument('--disk-cache-dir', type=str, help='Directory to save the disk cache to. Defaults to "cache/".')
parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
parser.add_argument('--bf16', action='store_true', help='DeepSpeed: Instantiate the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
# Default None (not 0): the DeepSpeed setup falls back to the LOCAL_RANK
# environment variable only when this is None, so a numeric default would make
# that fallback unreachable for launchers that set LOCAL_RANK without passing
# --local_rank.
parser.add_argument('--local_rank', type=int, default=None, help='DeepSpeed: Optional argument for distributed setups.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
args = parser.parse_args()

if (args.chat or args.cai_chat) and not args.no_stream:
    print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")
2023-01-29 18:27:22 +01:00
2023-01-15 19:23:41 +01:00
# Default interface settings; keys ending in "_pygmalion" are the chat-mode
# variants. A JSON file passed via --settings overrides individual keys.
settings = {
    'max_new_tokens': 200,
    'max_new_tokens_min': 1,
    'max_new_tokens_max': 2000,
    'preset': 'NovelAI-Sphinx Moth',
    'name1': 'Person 1',
    'name2': 'Person 2',
    'context': 'This is a conversation between two people.',
    'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
    'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
    'stop_at_newline': True,
    'history_size': 0,
    'history_size_min': 0,
    'history_size_max': 64,
    'preset_pygmalion': 'Pygmalion',
    'name1_pygmalion': 'You',
    'name2_pygmalion': 'Kawaii',
    'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
    'stop_at_newline_pygmalion': False,
}

if args.settings is not None and Path(args.settings).exists():
    # `with` closes the file; explicit UTF-8 so non-ASCII values (names,
    # contexts) decode the same on every platform regardless of locale.
    with open(Path(args.settings), 'r', encoding='utf-8') as settings_file:
        new_settings = json.load(settings_file)
    settings.update(new_settings)
2023-01-15 04:39:51 +01:00
2023-02-01 13:57:27 +01:00
if args.deepspeed:
    # DeepSpeed support is optional, so its imports only happen when requested.
    import deepspeed
    from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
    from modules.deepspeed_parameters import generate_ds_config

    # Distributed setup: use --local_rank when it is set, otherwise the
    # LOCAL_RANK environment variable (default "0").
    if args.local_rank is None:
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
    else:
        local_rank = args.local_rank
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    torch.cuda.set_device(local_rank)
    deepspeed.init_distributed()

    # NOTE(review): the second argument looks like a batch size of one per
    # process - confirm against modules.deepspeed_parameters.generate_ds_config.
    ds_config = generate_ds_config(args.bf16, 1 * world_size, args.nvme_offload_dir)
    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
2022-12-21 17:27:31 +01:00
def load_model(model_name):
    """Load the model and tokenizer for ``model_name``.

    The loading strategy depends on the command-line flags in ``args``:

    * no device/memory flags: ``torch-dumps/<name>.pt`` if it exists;
      8-bit with ``device_map='auto'`` for 13b/20b/30b gpt-neo/opt/galactica
      models; otherwise fp16 on CUDA;
    * ``--deepspeed``: the HF model wrapped in a DeepSpeed ZeRO-3 engine;
    * any other flag combination: ``from_pretrained`` with keyword
      arguments derived from the flags.

    Args:
        model_name: directory name under ``models/`` (or file stem under
            ``torch-dumps/``).

    Returns:
        ``(model, tokenizer)``. The tokenizer truncates from the left.
    """
    print(f"Loading {model_name}...")
    t0 = time.time()

    # Default settings
    if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None or args.deepspeed):
        if Path(f"torch-dumps/{model_name}.pt").exists():
            print("Loading in .pt format...")
            # NOTE: torch.load unpickles the file; only load dumps you trust.
            model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
        elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

    # DeepSpeed ZeRO-3
    elif args.deepspeed:
        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
        model.module.eval()  # Inference
        print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")

    # Custom
    else:
        # Keyword arguments for from_pretrained, built directly instead of
        # assembling a source string and eval()-ing it.
        params = {'low_cpu_mem_usage': True}

        if args.cpu:
            params['torch_dtype'] = torch.float32
        else:
            params['device_map'] = 'auto'
            if args.load_in_8bit:
                params['load_in_8bit'] = True
            else:
                params['torch_dtype'] = torch.float16

            if args.gpu_memory:
                params['max_memory'] = {0: f"{args.gpu_memory}GiB", 'cpu': f"{args.cpu_memory or '99'}GiB"}
            elif args.cpu_memory and not args.load_in_8bit:
                # Only --cpu-memory was given: derive a GPU cap from the card's
                # total memory (converted to MiB), keeping some headroom.
                total_mem = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024)
                suggestion = round((total_mem - 1000) / 1000) * 1000
                if total_mem - suggestion < 800:
                    suggestion -= 1000
                suggestion = int(round(suggestion / 1000))
                print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
                params['max_memory'] = {0: f"{suggestion}GiB", 'cpu': f"{args.cpu_memory}GiB"}

            if args.disk:
                params['offload_folder'] = args.disk_cache_dir or 'cache'

        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), **params)

    # Loading the tokenizer: gpt4chan models reuse the GPT-J tokenizer when available.
    if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
        tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
    else:
        tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
    tokenizer.truncation_side = 'left'

    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer
2023-02-08 02:08:21 +01:00
def load_model_wrapper(selected_model):
    """Replace the globally loaded model/tokenizer with ``selected_model``.

    Does nothing if ``selected_model`` is already the current ``model_name``.
    """
    global model_name, model, tokenizer

    if selected_model == model_name:
        return

    model_name = selected_model
    # Drop the old references and collect before loading, so the previous
    # model's memory can be reclaimed first. This applies on CPU too, not only
    # when freeing cached CUDA memory.
    model = tokenizer = None
    gc.collect()
    if not args.cpu:
        torch.cuda.empty_cache()
    model, tokenizer = load_model(model_name)
def load_preset_values(preset_menu, return_dict=False):
    """Read ``presets/<preset_menu>.txt`` and return its generation parameters.

    The preset file contains comma-separated ``name=value`` entries. Each value
    is parsed as a Python literal and overrides the defaults below. Entries
    named ``tokens``, and entries that are not exactly one ``name=value`` pair,
    are ignored. ``temperature`` is capped at 1.99, the top of its UI slider.

    Args:
        preset_menu: preset name (file stem inside ``presets/``).
        return_dict: if True, return the whole parameter dict.

    Returns:
        The parameter dict, or a tuple ordered as (do_sample, temperature,
        top_p, typical_p, repetition_penalty, top_k, min_length,
        no_repeat_ngram_size, num_beams, length_penalty, early_stopping).
        That is the order of the UI components this function is wired to.

    Raises:
        ValueError or SyntaxError: a value is not a plain Python literal.
    """
    # literal_eval parses numbers/booleans/strings without executing code, unlike eval.
    import ast

    generate_params = {
        'do_sample': True,
        'temperature': 1,
        'top_p': 1,
        'typical_p': 1,
        'repetition_penalty': 1,
        'top_k': 50,
        'num_beams': 1,
        'min_length': 0,
        'length_penalty': 1,
        'no_repeat_ngram_size': 0,
        'early_stopping': False,
    }

    with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
        preset = infile.read()
    for entry in preset.split(','):
        pair = entry.strip().split('=')
        if len(pair) == 2 and pair[0].strip() != 'tokens':
            generate_params[pair[0].strip()] = ast.literal_eval(pair[1].strip())

    generate_params['temperature'] = min(1.99, generate_params['temperature'])

    if return_dict:
        return generate_params
    return (generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'],
            generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'],
            generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'],
            generate_params['length_penalty'], generate_params['early_stopping'])
2023-02-08 02:08:21 +01:00
2023-01-06 06:26:33 +01:00
def fix_gpt4chan(s):
    """Remove empty replies from gpt4chan outputs.

    A reply is dropped when its body is only a ``>>`` quote, only spaces, or two
    blank lines. The removal passes run 10 times so that replies exposed by an
    earlier pass are caught as well.
    """
    empty_reply_patterns = (
        "--- [0-9]*\n>>[0-9]*\n---",
        "--- [0-9]*\n *\n---",
        "--- [0-9]*\n\n\n---",
    )
    for _ in range(10):
        for pattern in empty_reply_patterns:
            s = re.sub(pattern, "---", s)
    return s
2023-01-16 20:35:45 +01:00
def fix_galactica(s):
    """Fix the LaTeX equations in galactica output.

    Every ``\\[``, ``\\]``, ``\\(`` and ``\\)`` delimiter becomes ``$``. Any
    resulting (or pre-existing) ``$$`` is then collapsed to a single ``$``.
    """
    # Order matters: "$$" must be collapsed after the delimiters are replaced.
    for delimiter in (r'\[', r'\]', r'\(', r'\)', r'$$'):
        s = s.replace(delimiter, r'$')
    return s
2023-01-25 14:17:55 +01:00
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
    """Tokenize ``prompt`` into a tensor of ids, leaving room for generation.

    The prompt is truncated to ``2048 - tokens_to_generate`` tokens. The
    tokenizer's truncation side is set to 'left' in ``load_model``. The result
    stays on the CPU with ``--cpu``, goes to ``local_rank`` with
    ``--deepspeed``, and otherwise goes to the default CUDA device.
    """
    budget = 2048 - tokens_to_generate
    ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=budget, add_special_tokens=add_special_tokens)

    if args.cpu:
        return ids
    if args.deepspeed:
        return ids.to(device=local_rank)
    return ids.cuda()
2023-01-18 00:16:23 +01:00
2023-01-19 14:43:05 +01:00
def decode(output_ids):
    """Convert token ids back to text, dropping special tokens and any literal ``<|endoftext|>``."""
    text = tokenizer.decode(output_ids, skip_special_tokens=True)
    return text.replace(r'<|endoftext|>', '')
def formatted_outputs(reply, model_name):
    """Shape a generated reply for the active UI mode.

    Chat modes receive the plain string. Otherwise a 3-tuple is returned:
    (text, markdown text, html). The markdown slot only carries the reply for
    GALACTICA models.
    """
    if args.chat or args.cai_chat:
        return reply

    lowered = model_name.lower()
    if lowered.startswith('galactica'):
        fixed = fix_galactica(reply)
        return fixed, fixed, generate_basic_html(fixed)
    if lowered.startswith(('gpt4chan', 'gpt-4chan', '4chan')):
        fixed = fix_gpt4chan(reply)
        return fixed, 'Only applicable for GALACTICA models.', generate_4chan_html(fixed)
    return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
2023-01-19 14:43:05 +01:00
2023-02-08 03:11:04 +01:00
def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=None, stopping_string=None):
    """Generate a continuation of ``question`` and yield formatted outputs.

    Args:
        question: prompt text.
        tokens: number of new tokens to generate. ``max_new_tokens`` is accepted
            for call compatibility but is not used; ``tokens`` is the budget.
        do_sample, temperature, top_p, typical_p, repetition_penalty, top_k,
        min_length, no_repeat_ngram_size, num_beams, length_penalty,
        early_stopping: forwarded unchanged to ``model.generate``.
        eos_token: optional string whose last token id replaces
            ``tokenizer.eos_token_id`` as the end-of-sequence id.
        stopping_string: optional string; generation stops once its tokens
            appear after the prompt.

    Yields:
        ``formatted_outputs(reply, model_name)``. With ``--no-stream`` this is
        yielded once for the full reply. Otherwise the prompt is yielded first,
        then one output per chunk of up to 8 new tokens.
    """
    original_question = question
    if not (args.chat or args.cai_chat):
        question = apply_extensions(question, "input")
    if args.verbose:
        print(f"\n\n{question}\n--------------------\n")

    input_ids = encode(question, tokens)
    # Outputs are moved to CUDA except on CPU or DeepSpeed.
    move_to_cuda = not (args.cpu or args.deepspeed)

    if eos_token is None:
        n = tokenizer.eos_token_id
    else:
        # int(): pass a plain id rather than a 0-d tensor to generate().
        n = int(tokenizer.encode(eos_token, return_tensors='pt')[0][-1])

    if stopping_string is not None:
        # The stopping_criteria code below was copied from
        # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
        t = encode(stopping_string, 0, add_special_tokens=False)
        stopping_criteria_list = transformers.StoppingCriteriaList([
            _SentinelTokenStoppingCriteria(
                sentinel_token_ids=t,
                starting_idx=len(input_ids[0])
            )
        ])
    else:
        stopping_criteria_list = None

    # Keyword arguments passed straight to model.generate. Previously they were
    # rendered into a source string and eval()-ed, which broke for tensor values.
    generate_params = {
        'eos_token_id': n,
        'stopping_criteria': stopping_criteria_list,
        'do_sample': do_sample,
        'temperature': temperature,
        'top_p': top_p,
        'typical_p': typical_p,
        'repetition_penalty': repetition_penalty,
        'top_k': top_k,
        'min_length': min_length,
        'no_repeat_ngram_size': no_repeat_ngram_size,
        'num_beams': num_beams,
        'length_penalty': length_penalty,
        'early_stopping': early_stopping,
        # Streaming generates 8 tokens per call and feeds the output back in.
        'max_new_tokens': tokens if args.no_stream else 8,
    }
    if args.deepspeed:
        generate_params['synced_gpus'] = True

    def _generate(ids):
        # One generate() call without autograd, moved to CUDA when applicable.
        with torch.no_grad():
            out = model.generate(ids, **generate_params)
        return out.cuda() if move_to_cuda else out

    # Generate the entire reply at once
    if args.no_stream:
        t0 = time.time()
        output = _generate(input_ids)
        reply = decode(output[0])
        t1 = time.time()
        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output[0])-len(input_ids[0])} tokens)")
        if not (args.chat or args.cai_chat):
            reply = original_question + apply_extensions(reply[len(question):], "output")
        yield formatted_outputs(reply, model_name)

    # Generate the reply 8 tokens at a time
    else:
        yield formatted_outputs(original_question, model_name)
        for i in tqdm(range(tokens//8+1)):
            output = _generate(input_ids)
            reply = decode(output[0])
            if not (args.chat or args.cai_chat):
                reply = original_question + apply_extensions(reply[len(question):], "output")
            yield formatted_outputs(reply, model_name)
            input_ids = output
            if output[0][-1] == n:
                break
2023-01-19 01:37:21 +01:00
2023-01-27 04:40:39 +01:00
def apply_extensions(text, typ):
    """Pass ``text`` through one hook of every enabled extension.

    ``typ`` selects the hook: "input" -> ``input_modifier``, "output" ->
    ``output_modifier``, "bot_prefix" -> ``bot_prefix_modifier``. Extensions
    are visited sorted by ``extension_state[ext][1]``. Extensions that lack the
    hook are skipped, and an unknown ``typ`` returns ``text`` unchanged.
    """
    hook_name = {
        "input": "input_modifier",
        "output": "output_modifier",
        "bot_prefix": "bot_prefix_modifier",
    }.get(typ)
    if hook_name is None:
        return text

    for ext in sorted(extension_state, key=lambda x: extension_state[x][1]):
        if extension_state[ext][0] == True:
            # The extension's script module, previously reached via eval() of
            # its dotted path; assumes it has already been imported.
            script = sys.modules[f"extensions.{ext}.script"]
            if hasattr(script, hook_name):
                text = getattr(script, hook_name)(text)
    return text
2023-01-29 13:48:18 +01:00
def update_extensions_parameters(*values):
    """Assign positional ``values`` to the ``params`` of enabled extensions.

    Values are consumed in the order the loop visits parameters: extensions
    sorted by ``extension_state[ext][1]``, then each ``params`` dict in
    iteration order. Parameters beyond the number of supplied values are left
    unchanged.
    """
    i = 0
    for ext in sorted(extension_state, key=lambda x: extension_state[x][1]):
        if extension_state[ext][0] == True:
            # Assumes the extension's script module has already been imported.
            params = sys.modules[f"extensions.{ext}.script"].params
            for param in params:
                if i < len(values):
                    params[param] = values[i]
                i += 1
2023-02-08 02:08:21 +01:00
def get_available_models():
    """List model names from ``models/`` and ``torch-dumps/``, case-insensitively sorted.

    ``.txt`` entries are skipped. A trailing ``.pt`` is removed so torch dumps
    share names with their model directories. Only the suffix is removed, so a
    name that contains ".pt" elsewhere keeps it.
    """
    entries = list(Path('models/').glob('*')) + list(Path('torch-dumps/').glob('*'))
    names = set()
    for entry in entries:
        name = str(entry.name)
        if name.endswith('.txt'):
            continue
        if name.endswith('.pt'):
            name = name[:-len('.pt')]
        names.add(name)
    return sorted(names, key=str.lower)
def get_available_presets():
    """Return the preset names (``presets/*.txt`` without the extension), case-insensitively sorted."""
    # Every match ends in ".txt", so rsplit always finds the extension dot.
    stems = {path.name.rsplit('.', 1)[0] for path in Path('presets').glob('*.txt')}
    return sorted(stems, key=str.lower)
def get_available_characters():
    """Return "None" followed by the character names from ``characters/*.json``, case-insensitively sorted."""
    # Every match ends in ".json", so rsplit always finds the extension dot.
    stems = {path.name.rsplit('.', 1)[0] for path in Path('characters').glob('*.json')}
    return ["None", *sorted(stems, key=str.lower)]
def get_available_extensions():
    """Return the names of ``extensions/`` subdirectories that contain a ``script.py``, case-insensitively sorted."""
    # For extensions/<name>/script.py the parent directory's name is <name>.
    names = {script.parent.name for script in Path('extensions').glob('*/script.py')}
    return sorted(names, key=str.lower)
2023-01-29 13:48:18 +01:00
def create_extensions_block():
    """Build the "Extensions parameters" UI section for the enabled extensions.

    One widget is created per parameter whose value is a str (Textbox),
    int/float (Number) or bool (Checkbox). Its initial value comes from
    ``settings['<ext>-<param>']`` when present, otherwise from the extension's
    ``params``. All defaults are applied immediately. The "Apply" button writes
    each widget's value back to the parameter it was created for.
    """
    extensions_ui_elements = []
    default_values = []
    # (params dict, key) for each parameter that received a widget, in the same
    # order as extensions_ui_elements. Parameters of other types get no widget,
    # so a positional mapping over all parameters would misalign.
    ui_targets = []

    gr.Markdown('## Extensions parameters')
    for ext in sorted(extension_state, key=lambda x: extension_state[x][1]):
        if extension_state[ext][0] == True:
            # Assumes the extension's script module has already been imported.
            params = sys.modules[f"extensions.{ext}.script"].params
            for param in params:
                _id = f"{ext}-{param}"
                default_value = settings[_id] if _id in settings else params[param]
                default_values.append(default_value)

                param_type = type(params[param])
                if param_type == str:
                    widget = gr.Textbox(value=default_value, label=f"{ext}-{param}")
                elif param_type in [int, float]:
                    widget = gr.Number(value=default_value, label=f"{ext}-{param}")
                elif param_type == bool:
                    widget = gr.Checkbox(value=default_value, label=f"{ext}-{param}")
                else:
                    continue
                extensions_ui_elements.append(widget)
                ui_targets.append((params, param))

    # default_values covers every parameter, in the same order that
    # update_extensions_parameters walks them.
    update_extensions_parameters(*default_values)

    def apply_ui_values(*values):
        # Write each widget value to the parameter that widget was built for.
        for (params, param), value in zip(ui_targets, values):
            params[param] = value

    btn_extensions = gr.Button("Apply")
    btn_extensions.click(apply_ui_values, [*extensions_ui_elements], [])
2023-02-08 02:08:21 +01:00
def create_settings_menus():
    """Build the model/preset selectors and the generation-parameter controls.

    Initial values come from the preset ``settings[f'preset{suffix}']``.
    Changing the model menu reloads the model, and changing the preset menu
    refreshes every parameter control.

    Returns:
        (preset_menu, do_sample, temperature, top_p, typical_p,
        repetition_penalty, top_k, min_length, no_repeat_ngram_size,
        num_beams, length_penalty, early_stopping) components.
    """
    defaults = load_preset_values(settings[f'preset{suffix}'], return_dict=True)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                create_refresh_button(model_menu, lambda: None, lambda: {"choices": get_available_models()}, "refresh-button")
        with gr.Column():
            with gr.Row():
                preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Generation parameters preset')
                create_refresh_button(preset_menu, lambda: None, lambda: {"choices": get_available_presets()}, "refresh-button")

    with gr.Accordion("Custom generation parameters", open=False):
        with gr.Row():
            with gr.Column():
                do_sample = gr.Checkbox(value=defaults['do_sample'], label="do_sample")
                temperature = gr.Slider(0.01, 1.99, value=defaults['temperature'], step=0.01, label="temperature")
                top_p = gr.Slider(0.0, 1.0, value=defaults['top_p'], step=0.01, label="top_p")
                typical_p = gr.Slider(0.0, 1.0, value=defaults['typical_p'], step=0.01, label="typical_p")
            with gr.Column():
                repetition_penalty = gr.Slider(1.0, 4.99, value=defaults['repetition_penalty'], step=0.01, label="repetition_penalty")
                top_k = gr.Slider(0, 200, value=defaults['top_k'], step=1, label="top_k")
                no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=defaults["no_repeat_ngram_size"], label="no_repeat_ngram_size")

        gr.Markdown("Special parameters (only use them if you really need them):")
        with gr.Row():
            with gr.Column():
                # Minimum is 1: generate() rejects num_beams=0.
                num_beams = gr.Slider(1, 20, step=1, value=defaults["num_beams"], label="num_beams")
                length_penalty = gr.Slider(0, 5, value=defaults["length_penalty"], label="length_penalty")
            with gr.Column():
                # min_length is only editable without streaming; streaming generates 8 tokens per call.
                min_length = gr.Slider(0, 2000, step=1, value=defaults["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
                early_stopping = gr.Checkbox(value=defaults["early_stopping"], label="early_stopping")

    model_menu.change(load_model_wrapper, [model_menu], [])
    preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping])
    return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping
2023-02-08 02:08:21 +01:00
def clean_chat_message(text):
    """Normalize chat-message line breaks.

    Every newline is doubled, runs of three or more newlines are collapsed to
    exactly two, and surrounding whitespace is stripped.
    """
    doubled = text.replace('\n', '\n\n')
    collapsed = re.sub(r"\n{3,}", "\n\n", doubled)
    return collapsed.strip()
2023-01-22 04:49:59 +01:00
2023-02-08 02:08:21 +01:00
def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=False):
    """Build the chat prompt sent to the model.

    The prompt is ``context``, then as many of the most recent
    ``history['internal']`` exchanges as fit the token budget. It ends with
    either the new user turn plus the bot's name prefix (passed through the
    "bot_prefix" extensions), or, when ``impersonate`` is True, just
    ``"{name1}:"``.

    Args:
        text: the new user message; normalized with clean_chat_message.
        tokens: tokens reserved for generation. The prompt is kept under
            ``2048 - tokens`` tokens as measured by ``encode``.
        name1: the user's name.
        name2: the bot's name.
        context: text placed at the top of the prompt.
        history_size: stop adding history once this many rows were added;
            0 means no row limit.
        impersonate: end the prompt with the user's name instead of the bot's.

    Returns:
        The prompt string.
    """
    text = clean_chat_message(text)

    rows = [f"{context.strip()}\n"]
    i = len(history['internal'])-1
    count = 0
    # Walk history newest -> oldest, inserting right after the context row so
    # the final order is chronological. encode() truncates to 2048 - tokens,
    # so the length check stops the loop once the budget is reached.
    while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
        rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
        count += 1
        # When the user half is the '<|BEGIN-VISIBLE-CHAT|>' marker, only the bot row is added.
        if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
            rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
            count += 1
        i -= 1
        if history_size != 0 and count >= history_size:
            break

    if not impersonate:
        rows.append(f"{name1}: {text}\n")
        rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
        limit = 3
    else:
        rows.append(f"{name1}:")
        limit = 2

    # Drop the oldest history rows two at a time until the prompt fits. `limit`
    # counts the context row plus the trailing row(s) appended above.
    # NOTE(review): if a single (bot-only) history row remains, the second pop
    # removes one of those trailing rows - confirm this cannot happen in practice.
    while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
        rows.pop(1)
        rows.pop(1)

    question = ''.join(rows)
    return question
def extract_message_from_reply(question, reply, current, other, check, extensions=False):
    """Cut the newest ``current`` speaker turn out of a generated ``reply``.

    Args:
        question: the prompt that was sent to the model.
        reply: the decoded model output, which starts with the prompt.
        current: name of the speaker whose turn is being extracted.
        other: name of the other speaker; their prefix ends the turn.
        check: if True, keep only the first line of the turn.
        extensions: if True, the speaker prefix is measured after the
            "bot_prefix" extensions are applied to it.

    Returns:
        ``(message, next_character_found, substring_found)``.
        ``next_character_found`` is True when ``"\\n{other}:"`` was reached.
        ``substring_found`` is True when the message ends with a partial
        ``"\\n{other}:"`` that may still be completing. Both flags are only set
        when ``check`` is False.
    """
    next_character_found = False
    substring_found = False

    # re.escape keeps names containing regex metacharacters ('+', '(', '.', ...)
    # matching literally.
    speaker_pattern = re.compile(f"(^|\n){re.escape(current)}:")
    previous_idx = [m.start() for m in speaker_pattern.finditer(question)]
    idx = [m.start() for m in speaker_pattern.finditer(reply)]
    # The turn being generated is the last speaker prefix already present in the prompt.
    idx = idx[len(previous_idx)-1]

    if extensions:
        reply = reply[idx + 1 + len(apply_extensions(f"{current}:", "bot_prefix")):]
    else:
        reply = reply[idx + 1 + len(f"{current}:"):]

    if check:
        reply = reply.split('\n')[0].strip()
    else:
        idx = reply.find(f"\n{other}:")
        if idx != -1:
            reply = reply[:idx]
            next_character_found = True
        reply = clean_chat_message(reply)

        # Detect if something like "\nYo" is generated just before
        # "\nYou:" is completed
        tmp = f"\n{other}:"
        for j in range(1, len(tmp)):
            if reply[-j:] == tmp[:j]:
                substring_found = True

    return reply, next_character_found, substring_found
2023-02-08 03:11:04 +01:00
def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
    """Generate the bot's answer to ``text`` and stream the visible chat history.

    A new ``['', '']`` entry is appended to both ``history['internal']`` and
    ``history['visible']``, then updated as the reply grows. Snapshots are not
    yielded while the reply ends in a partial ``"\\n{name1}:"``. Streaming stops
    once that prefix is complete, and a final snapshot is always yielded.
    """
    user_text = text
    modified_text = apply_extensions(user_text, "input")
    prompt = generate_chat_prompt(modified_text, tokens, name1, name2, context, history_size)

    for key in ('internal', 'visible'):
        history[key].append(['', ''])

    stop_token = '\n' if check else None
    replies = generate_reply(prompt, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=stop_token, stopping_string=f"\n{name1}:")
    for raw_reply in replies:
        message, next_character_found, substring_found = extract_message_from_reply(prompt, raw_reply, name2, name1, check, extensions=True)
        history['internal'][-1] = [modified_text, message]
        history['visible'][-1] = [user_text, apply_extensions(message, "output")]
        if not substring_found:
            yield history['visible']
        if next_character_found:
            break

    yield history['visible']
2023-02-08 03:11:04 +01:00
def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
    """Generate the user's (``name1``'s) next message and stream it.

    Chat history is not modified. Each yielded value has the "output"
    extensions applied. Values are not yielded while the text ends in a
    partial ``"\\n{name2}:"``, and streaming stops once that prefix is complete.
    """
    prompt = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True)
    stop_token = '\n' if check else None

    replies = generate_reply(prompt, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=stop_token, stopping_string=f"\n{name2}:")
    for raw_reply in replies:
        message, next_character_found, substring_found = extract_message_from_reply(prompt, raw_reply, name1, name2, check, extensions=False)
        if not substring_found:
            yield apply_extensions(message, "output")
        if next_character_found:
            break

    yield apply_extensions(message, "output")
2023-02-08 03:11:04 +01:00
def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
    """Stream chatbot_wrapper's output rendered as CAI-style chat HTML.

    Takes exactly the same arguments as chatbot_wrapper. Every visible history
    it yields is converted with generate_chat_html for the current character.
    """
    updates = chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size)
    for visible_history in updates:
        rendered = generate_chat_html(visible_history, name1, name2, character)
        yield rendered
2023-02-08 03:11:04 +01:00
def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size):
    """Drop the last exchange and stream a new reply to the same user message.

    The `text` argument is ignored. The user message of the removed visible entry
    is re-sent through chatbot_wrapper (or cai_chatbot_wrapper in --cai-chat
    mode).

    If there is nothing to regenerate, the current history is re-displayed
    unchanged. That is the case when the history is empty, or when the last
    entry is the character greeting (marked '<|BEGIN-VISIBLE-CHAT|>'). Previously
    the first case raised IndexError and the second destroyed the greeting.
    """
    nothing_to_regenerate = (
        len(history['visible']) == 0
        or len(history['internal']) == 0
        or history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>'
    )
    if nothing_to_regenerate:
        if args.cai_chat:
            yield generate_chat_html(history['visible'], name1, name2, character)
        else:
            yield history['visible']
        return

    last = history['visible'].pop()
    history['internal'].pop()
    # Re-send the user's previous message (the visible, pre-extension text).
    text = last[0]
    wrapper = cai_chatbot_wrapper if args.cai_chat else chatbot_wrapper
    yield from wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size)
def remove_last_message(name1, name2):
    """Remove the most recent exchange from the chat history.

    The character greeting (internal entry '<|BEGIN-VISIBLE-CHAT|>') is never
    removed. Calling this on an empty history is a no-op; it used to raise
    IndexError.

    Returns:
        (display, text): the refreshed display (HTML in --cai-chat mode, else the
        visible history) and the removed user message, or '' if nothing was removed.
    """
    can_remove = (
        len(history['visible']) > 0
        and len(history['internal']) > 0
        and history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>'
    )
    if can_remove:
        last = history['visible'].pop()
        history['internal'].pop()
    else:
        last = ['', '']

    if args.cai_chat:
        return generate_chat_html(history['visible'], name1, name2, character), last[0]
    else:
        return history['visible'], last[0]
def send_last_reply_to_input():
    """Return the bot's most recent visible reply, or '' if the chat is empty."""
    if not history['visible']:
        return ''
    return history['visible'][-1][1]
def replace_last_reply(text, name1, name2):
    """Overwrite the bot's latest reply with `text` and return the refreshed display.

    The visible history stores `text` as-is. The internal history stores it after
    the "input" extensions. Nothing is modified when the chat is empty.
    """
    visible = history['visible']
    if visible:
        visible[-1][1] = text
        history['internal'][-1][1] = apply_extensions(text, "input")

    if args.cai_chat:
        return generate_chat_html(visible, name1, name2, character)
    return visible
def clear_html():
    """Return the CAI-style chat HTML for an empty conversation."""
    no_messages = []
    return generate_chat_html(no_messages, "", "", character)
def clear_chat_log(_character, name1, name2):
    """Reset the chat history and return the refreshed display.

    When a character is selected, the conversation is reset to just after the
    character greeting (internal entry containing '<|BEGIN-VISIBLE-CHAT|>'):
      - the internal history keeps everything up to and including that entry;
      - the visible history keeps only the greeting, rendered through the
        "output" extensions, as load_character does.
    If no greeting entry exists (e.g. after uploading a history file), both
    histories are emptied. Previously nothing was cleared in that case.

    Args:
        _character: value of the character dropdown ('None' for no character).
        name1, name2: user's and bot's names, for HTML rendering.
    """
    global history
    if _character != 'None':
        for i, entry in enumerate(history['internal']):
            if '<|BEGIN-VISIBLE-CHAT|>' in entry[0]:
                history['visible'] = [['', apply_extensions(entry[1], "output")]]
                history['internal'] = history['internal'][:i+1]
                break
        else:
            # No greeting marker found: fall back to a full reset.
            history['internal'] = []
            history['visible'] = []
    else:
        history['internal'] = []
        history['visible'] = []

    if args.cai_chat:
        return generate_chat_html(history['visible'], name1, name2, character)
    else:
        return history['visible']
2023-01-27 04:40:39 +01:00
2023-02-08 02:08:21 +01:00
def redraw_html(name1, name2):
    """Render the current visible chat history as CAI-style HTML."""
    visible_history = history['visible']
    return generate_chat_html(visible_history, name1, name2, character)
def tokenize_dialogue(dialogue, name1, name2):
    """Split an example-dialogue string into [user_message, bot_message] pairs.

    Preprocessing:
      - '<START>' / '<start>' separators are removed;
      - lines beginning with 'Anon:' / 'anon:' are attributed to name1;
      - lines beginning with '[CHARACTER]:' are attributed to name2.

    A message runs from a '<name>:' prefix at the start of a line up to the next
    such prefix. A pair is appended whenever a name2 message is seen, using the
    most recent name1 message (or '' if none). Pairs with both sides empty are
    skipped.

    Names are matched literally: they are regex-escaped, so names such as 'C++'
    or 'Dr. X' work. They are also inserted verbatim by the substitutions.

    Args:
        dialogue: raw dialogue text.
        name1: the user's name.
        name2: the bot's name.

    Returns:
        A list of [user_text, bot_text] pairs; empty if no speaker prefix is found.
    """
    _history = []

    dialogue = dialogue.replace('<START>', '').replace('<start>', '')
    # Callable replacements insert the names literally (no backslash processing).
    dialogue = re.sub('(\n|^)[Aa]non:', lambda m: m.group(1) + name1 + ':', dialogue)
    dialogue = re.sub(r'(\n|^)\[CHARACTER\]:', lambda m: m.group(1) + name2 + ':', dialogue)

    speaker = re.compile(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):")
    idx = [m.start() for m in speaker.finditer(dialogue)]
    if len(idx) == 0:
        return _history

    # Each message is the text between two consecutive speaker prefixes.
    bounds = idx + [len(dialogue)]
    messages = [dialogue[start:end].strip() for start, end in zip(bounds, bounds[1:])]

    prefix1 = f'{name1}:'
    prefix2 = f'{name2}:'
    entry = ['', '']
    for message in messages:
        if message.startswith(prefix1):
            entry[0] = message[len(prefix1):].strip()
        elif message.startswith(prefix2):
            entry[1] = message[len(prefix2):].strip()
            if not (len(entry[0]) == 0 and len(entry[1]) == 0):
                _history.append(entry)
            entry = ['', '']

    print(f"\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
    for row in _history:
        for column in row:
            print("\n")
            for line in column.strip().split('\n'):
                print("|  "+line+"\n")
            print("|\n")
        print("------------------------------")

    return _history
def save_history():
    """Dump both chat histories to logs/[<character>_]<YYYYmmdd-HHMMSS>.json.

    The JSON object has keys 'data' (internal history) and 'data_visible'
    (visible history). The logs/ directory is created if missing.

    Returns:
        The Path of the written file.
    """
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    name_prefix = f"{character}_" if character else ''
    fname = f"{name_prefix}{timestamp}.json"

    log_dir = Path('logs')
    if not log_dir.exists():
        log_dir.mkdir()

    output_path = Path(f'logs/{fname}')
    payload = {'data': history['internal'], 'data_visible': history['visible']}
    with open(output_path, 'w') as f:
        json.dump(payload, f)
    return output_path
def load_history(file, name1, name2):
    """Replace the chat history with the contents of an uploaded file.

    Accepted formats:
      - this UI's export: {'data': internal, 'data_visible': visible (optional)};
      - Pygmalion AI's official web UI export: {'chat': ["Name: message", ...]};
      - anything else, including non-JSON text, is parsed as plain dialogue with
        tokenize_dialogue.

    Args:
        file: raw bytes of the upload (UTF-8, optionally with a byte-order mark).
        name1, name2: the user's and the bot's names, used to split dialogue.
    """
    global history
    # 'utf-8-sig' strips a leading BOM, which would otherwise make json.loads
    # fail and prevent a "Name:" prefix from matching at the start of the text.
    file = file.decode('utf-8-sig')
    try:
        j = json.loads(file)
        if 'data' in j:
            history['internal'] = j['data']
            if 'data_visible' in j:
                history['visible'] = j['data_visible']
            else:
                history['visible'] = copy.deepcopy(history['internal'])
        # Compatibility with Pygmalion AI's official web UI
        elif 'chat' in j:
            # Strip the "Name:" prefix from each line.
            messages = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
            if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
                # The log opens with the bot's greeting: mark it as such, then pair the rest.
                history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', messages[0]]] + [[messages[i], messages[i+1]] for i in range(1, len(messages)-1, 2)]
                history['visible'] = copy.deepcopy(history['internal'])
                history['visible'][0][0] = ''
            else:
                history['internal'] = [[messages[i], messages[i+1]] for i in range(0, len(messages)-1, 2)]
                history['visible'] = copy.deepcopy(history['internal'])
    except Exception:
        # Not JSON, or JSON of an unexpected shape: best-effort parse as plain
        # dialogue. (Exception instead of a bare except, so that
        # KeyboardInterrupt/SystemExit are no longer swallowed.)
        history['internal'] = tokenize_dialogue(file, name1, name2)
        history['visible'] = copy.deepcopy(history['internal'])
def load_character(_character, name1, name2):
    """Switch to a character (or back to the default persona) and reset the chat.

    For a character, characters/<_character>.json is read and used as follows:
      - the context is built from 'char_persona' and 'world_scenario', followed
        by a '<START>' line;
      - 'example_dialogue' is tokenized into the internal history;
      - the greeting ('char_greeting', or "Hello there!") is appended to both
        histories.
    For 'None', the context and bot name come from the '*_pygmalion' settings.

    Args:
        _character: file stem under characters/, or the string 'None'.
        name1: the user's name, used when tokenizing the example dialogue.
        name2: the current bot name; replaced by the character's name.

    Returns:
        (name2, context, display), where display is HTML in --cai-chat mode and
        the visible history otherwise.
    """
    global history, character
    context = ""
    history['internal'] = []
    history['visible'] = []
    if _character != 'None':
        character = _character
        # Use a context manager so the file handle is closed promptly.
        with open(Path(f'characters/{_character}.json'), 'r') as f:
            data = json.load(f)
        name2 = data['char_name']
        if 'char_persona' in data and data['char_persona'] != '':
            context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
        if 'world_scenario' in data and data['world_scenario'] != '':
            context += f"Scenario: {data['world_scenario']}\n"
        context = f"{context.strip()}\n<START>\n"
        if 'example_dialogue' in data and data['example_dialogue'] != '':
            history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
        if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
            history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
            history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]]
        else:
            history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
            history['visible'] += [['', "Hello there!"]]
    else:
        character = None
        context = settings['context_pygmalion']
        name2 = settings['name2_pygmalion']

    if args.cai_chat:
        return name2, context, generate_chat_html(history['visible'], name1, name2, character)
    else:
        return name2, context, history['visible']
def upload_character(json_file, img, tavern=False):
    """Save an uploaded character definition (and optional picture) under characters/.

    An existing file is never overwritten: _001, _002, ... is appended to the
    name until it is free. With tavern=True, the 'TavernAI-' prefix is applied
    before that check, so the name that is checked is the name that is written.

    Args:
        json_file: the character JSON, as str or UTF-8 bytes; must contain "char_name".
        img: optional raw image bytes, saved as characters/<name>.png.
        tavern: prefix the saved name with 'TavernAI-'.

    Returns:
        The file stem the character was saved under.
    """
    json_file = json_file if isinstance(json_file, str) else json_file.decode('utf-8')
    data = json.loads(json_file)
    # char_name comes from an uploaded file: neutralise path separators so the
    # result cannot point outside characters/.
    base_name = str(data["char_name"]).replace('/', '_').replace('\\', '_')
    if tavern:
        base_name = f'TavernAI-{base_name}'

    outfile_name = base_name
    i = 1
    while Path(f'characters/{outfile_name}.json').exists():
        outfile_name = f'{base_name}_{i:03d}'
        i += 1

    with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
        f.write(json_file)
    if img is not None:
        img = Image.open(io.BytesIO(img))
        img.save(Path(f'characters/{outfile_name}.png'))
    print(f'New character saved to "characters/{outfile_name}.json".')
    return outfile_name
def upload_tavern_character(img, name1, name2):
    """Import a TavernAI character card (a PNG carrying the character as JSON).

    The card's 'chara' metadata entry is base64-encoded JSON. Its fields are
    mapped to this UI's character schema. The '{{user}}' and '{{char}}'
    placeholders are replaced by name1 and the character's name in every text
    field (previously only in example_dialogue). The result is saved through
    upload_character with tavern=True.

    Returns:
        The saved character's file stem (see upload_character).
    """
    _img = Image.open(io.BytesIO(img))
    # NOTE(review): kept deliberately. It presumably makes Pillow read the rest of
    # the PNG, so that metadata chunks after the image data appear in .info.
    # Confirm before removing.
    _img.getexif()
    decoded_string = base64.b64decode(_img.info['chara'])
    card = json.loads(decoded_string)
    char_name = card['name']
    _json = {
        "char_name": char_name,
        "char_persona": card['description'],
        "char_greeting": card["first_mes"],
        "example_dialogue": card['mes_example'],
        "world_scenario": card['scenario'],
    }
    for field in ("char_persona", "char_greeting", "example_dialogue", "world_scenario"):
        _json[field] = _json[field].replace('{{user}}', name1).replace('{{char}}', char_name)
    return upload_character(json.dumps(_json), img, tavern=True)
def upload_your_profile_picture(img):
    """Save uploaded image bytes as img_me.png, used as your profile picture in --cai-chat mode."""
    picture = Image.open(io.BytesIO(img))
    picture.save(Path('img_me.png'))
    print('Profile picture saved to "img_me.png"')
# Global variables
# Choices offered in the UI (model selection, preset and character menus).
available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()

available_extensions = get_available_extensions()
# extension name -> [True, index of that name in the --extensions list]
extension_state = {}
if args.extensions is not None:
    for i, ext in enumerate(args.extensions.split(',')):
        # Names not returned by get_available_extensions() are skipped silently.
        if ext in available_extensions:
            print(f'Loading the extension "{ext}"... ', end='')
            ext_string = f"extensions.{ext}.script"
            # exec at module level: the import binds the top-level name
            # `extensions` in this module's globals.
            exec(f"import {ext_string}")
            extension_state[ext] = [True, i]
            print(f'Ok.')
2023-01-06 23:56:44 +01:00
# Choosing the default model
# --model wins. Otherwise a single available model is used directly; with several,
# the user is asked interactively. The prompt repeats until a number in
# 1..len(available_models) is entered: previously "0" silently selected the last
# model (index -1), and non-numeric input crashed.
if args.model is not None:
    model_name = args.model
else:
    if len(available_models) == 0:
        print("No models are available! Please download at least one.")
        sys.exit(0)
    elif len(available_models) == 1:
        i = 0
    else:
        print("The following models are available:\n")
        for i, model in enumerate(available_models):
            print(f"{i+1}. {model}")
        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
        while True:
            try:
                i = int(input()) - 1
            except ValueError:
                i = -1
            if 0 <= i < len(available_models):
                break
            print(f"Please enter a number between 1 and {len(available_models)}.")
        print()
    model_name = available_models[i]
model, tokenizer = load_model(model_name)
loaded_preset = None
2023-01-06 23:56:44 +01:00
2023-01-09 00:10:31 +01:00
# UI settings
# Models whose name starts with a gpt4chan prefix get their own default prompt.
if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
    default_text = settings['prompt_gpt4chan']
else:
    default_text = settings['prompt']
description = "\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
css = (
    ".my-4 {margin-top: 0} "
    ".py-6 {padding-top: 2.5rem} "
    "#refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} "
    "#download-label, #upload-label {min-height: 0}"
)
# Gradio buttons keyed by label, and the generation events the "Stop" button cancels.
buttons = {}
gen_events = []

# Pygmalion models read the '*_pygmalion' variants of the chat settings.
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
# Shared chat state. 'visible' is what the chat display shows; 'internal' also
# holds the example dialogue and presumably feeds the prompt.
history = {'internal': [], 'visible': []}
character = None
2023-01-19 18:03:47 +01:00
2023-02-08 02:08:21 +01:00
# Build the gradio interface for one of three modes: chat (--chat/--cai-chat),
# notebook (--notebook), or the default two-column input/output layout.
if args.chat or args.cai_chat:
    with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto} .w-screen {width: unset}", analytics_enabled=False) as interface:
        # CAI-style chat renders HTML; plain chat uses gradio's Chatbot component.
        if args.cai_chat:
            display = gr.HTML(value=generate_chat_html([], "", "", character))
        else:
            display = gr.Chatbot()
        textbox = gr.Textbox(label='Input')
        with gr.Row():
            buttons["Stop"] = gr.Button("Stop")
            buttons["Generate"] = gr.Button("Generate")
            buttons["Regenerate"] = gr.Button("Regenerate")
        with gr.Row():
            buttons["Impersonate"] = gr.Button("Impersonate")
            buttons["Remove last"] = gr.Button("Remove last")
            buttons["Clear"] = gr.Button("Clear history")
        with gr.Row():
            buttons["Send last reply to input"] = gr.Button("Send last reply to input")
            buttons["Replace last reply"] = gr.Button("Replace last reply")

        with gr.Row():
            with gr.Column():
                max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
            with gr.Column():
                history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size in prompt (0 for no limit)', value=settings['history_size'])

        preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus()

        name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
        name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
        context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
        with gr.Row():
            character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
            create_refresh_button(character_menu, lambda: None, lambda: {"choices": get_available_characters()}, "refresh-button")
        with gr.Row():
            check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
        with gr.Row():
            with gr.Tab('Chat history'):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown('Upload')
                        upload = gr.File(type='binary')
                    with gr.Column():
                        gr.Markdown('Download')
                        download = gr.File()
                        buttons["Download"] = gr.Button(value="Click me")
            with gr.Tab('Upload character'):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown('1. Select the JSON file')
                        upload_char = gr.File(type='binary')
                    with gr.Column():
                        gr.Markdown('2. Select your character\'s profile picture (optional)')
                        upload_img = gr.File(type='binary')
                buttons["Upload character"] = gr.Button(value="Submit")
            with gr.Tab('Upload your profile picture'):
                upload_img_me = gr.File(type='binary')
            with gr.Tab('Upload TavernAI Character Card'):
                upload_img_tavern = gr.File(type='binary')

        if args.extensions is not None:
            create_extensions_block()

        # max_new_tokens appears twice on purpose: it fills both the `tokens` and
        # the `max_new_tokens` parameters of the chat wrappers.
        input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size_slider]
        # Generation events are collected in gen_events so "Stop" can cancel them.
        if args.cai_chat:
            gen_events.append(buttons["Generate"].click(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen"))
            gen_events.append(textbox.submit(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream))
        else:
            gen_events.append(buttons["Generate"].click(chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen"))
            gen_events.append(textbox.submit(chatbot_wrapper, input_params, display, show_progress=args.no_stream))
        gen_events.append(buttons["Regenerate"].click(regenerate_wrapper, input_params, display, show_progress=args.no_stream))
        # Impersonate writes its output into the input textbox, not the chat display.
        gen_events.append(buttons["Impersonate"].click(impersonate_wrapper, input_params, textbox, show_progress=args.no_stream))

        buttons["Send last reply to input"].click(send_last_reply_to_input, [], textbox, show_progress=args.no_stream)
        buttons["Replace last reply"].click(replace_last_reply, [textbox, name1, name2], display, show_progress=args.no_stream)
        buttons["Clear"].click(clear_chat_log, [character_menu, name1, name2], display)
        buttons["Remove last"].click(remove_last_message, [name1, name2], [display, textbox], show_progress=False)
        buttons["Stop"].click(None, None, None, cancels=gen_events)
        buttons["Download"].click(save_history, inputs=[], outputs=[download])
        buttons["Upload character"].click(upload_character, [upload_char, upload_img], [character_menu])
        # Clear the input textbox after these actions.
        for i in ["Generate", "Regenerate", "Replace last reply"]:
            buttons[i].click(lambda x: "", textbox, textbox, show_progress=False)
        textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
        character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display])
        upload_img_tavern.upload(upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu])
        upload.upload(load_history, [upload, name1, name2], [])
        upload_img_me.upload(upload_your_profile_picture, [upload_img_me], [])

        # Second listeners on the same upload events refresh the display.
        # NOTE(review): assumes gradio runs them after the listeners registered
        # above (load_history / upload_your_profile_picture) — confirm.
        if args.cai_chat:
            upload.upload(redraw_html, [name1, name2], [display])
            upload_img_me.upload(redraw_html, [name1, name2], [display])
        else:
            upload.upload(lambda: history['visible'], [], [display])
            upload_img_me.upload(lambda: history['visible'], [], [display])

elif args.notebook:
    # Notebook mode: generation output is written back into the input textbox.
    with gr.Blocks(css=css, analytics_enabled=False) as interface:
        gr.Markdown(description)
        with gr.Tab('Raw'):
            textbox = gr.Textbox(value=default_text, lines=23)
        with gr.Tab('Markdown'):
            markdown = gr.Markdown()
        with gr.Tab('HTML'):
            html = gr.HTML()

        buttons["Generate"] = gr.Button("Generate")
        buttons["Stop"] = gr.Button("Stop")

        max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
        preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus()

        if args.extensions is not None:
            create_extensions_block()

        gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen"))
        gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream))
        buttons["Stop"].click(None, None, None, cancels=gen_events)

else:
    # Default mode: input on the left, output (raw/markdown/HTML) on the right.
    with gr.Blocks(css=css, analytics_enabled=False) as interface:
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                textbox = gr.Textbox(value=default_text, lines=15, label='Input')
                max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
                buttons["Generate"] = gr.Button("Generate")
                with gr.Row():
                    with gr.Column():
                        buttons["Continue"] = gr.Button("Continue")
                    with gr.Column():
                        buttons["Stop"] = gr.Button("Stop")

                preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus()
                if args.extensions is not None:
                    create_extensions_block()

            with gr.Column():
                with gr.Tab('Raw'):
                    output_textbox = gr.Textbox(lines=15, label='Output')
                with gr.Tab('Markdown'):
                    markdown = gr.Markdown()
                with gr.Tab('HTML'):
                    html = gr.HTML()

        gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen"))
        gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream))
        # "Continue" feeds the current output back in as the prompt.
        gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream))
        buttons["Stop"].click(None, None, None, cancels=gen_events)
2022-12-21 17:27:31 +01:00
2023-01-25 20:10:35 +01:00
# Enable the request queue (needed for streaming and cancellable events), then start the server.
interface.queue()
launch_kwargs = {'prevent_thread_lock': True, 'share': args.share, 'server_port': args.listen_port}
if args.listen:
    # --listen: bind on all network interfaces instead of gradio's default host.
    launch_kwargs['server_name'] = "0.0.0.0"
interface.launch(**launch_kwargs)

# With prevent_thread_lock=True, launch() returns instead of blocking, so keep the
# main thread alive here.
while True:
    time.sleep(0.5)