2023-02-23 16:05:25 +01:00
import argparse
2023-05-04 02:43:17 +02:00
import logging
2023-04-14 16:07:28 +02:00
from pathlib import Path
import yaml
2023-02-23 16:05:25 +01:00
# Currently loaded model state.
model = tokenizer = None
# The string "None" (not the None object) is stored here; presumably the
# placeholder used when no model is selected.
model_name = "None"
model_type = None
lora_names = list()
soft_prompt_tensor = None
soft_prompt = False

# Chat variables
history = dict(internal=list(), visible=list())  # two independent lists
character = 'None'
stop_everything = False
processing_message = '*Is typing...*'

# UI elements (buttons, sliders, HTML, etc)
gradio = dict()

# For keeping the values of UI elements on page reload
persistent_interface_state = dict()

# Generation input parameters
input_params = list()

# For restarting the interface
need_restart = False
2023-02-23 17:42:23 +01:00
# Default interface settings. Built with keyword arguments, which keep the
# same key order as a dict literal would. Keys ending in _min / _max
# presumably bound the UI control for the matching setting.
settings = dict(
    # Generation length and seed
    max_new_tokens=200,
    max_new_tokens_min=1,
    max_new_tokens_max=2000,
    seed=-1,
    # Chat persona and prompt shape
    character='None',
    name1='You',
    name2='Assistant',
    context='This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
    greeting='',
    turn_template='',
    custom_stopping_strings='',
    stop_at_newline=False,
    # Token handling
    add_bos_token=True,
    ban_eos_token=False,
    skip_special_tokens=True,
    truncation_length=2048,
    truncation_length_min=0,
    truncation_length_max=8192,
    # Interface mode and chat behaviour
    mode='chat',
    chat_style='cai-chat',
    instruction_template='None',
    chat_prompt_size=2048,
    chat_prompt_size_min=0,
    chat_prompt_size_max=2048,
    chat_generation_attempts=1,
    chat_generation_attempts_min=1,
    chat_generation_attempts_max=5,
    # Extensions enabled by default
    default_extensions=[],
    chat_default_extensions=["gallery"],
    # The nested keys below are regular-expression strings (presumably matched
    # against the model name); 'default' is the fallback entry. Their keys are
    # not identifiers, so they stay as literals.
    presets={
        'default': 'Default',
        '.*(alpaca|llama|llava)': "LLaMA-Precise",
        '.*pygmalion': 'NovelAI-Storywriter',
        '.*RWKV': 'Naive',
        '.*moss': 'MOSS',
    },
    prompts={
        'default': 'QA',
        '.*(gpt4chan|gpt-4chan|4chan)': 'GPT-4chan',
        '.*oasst': 'Open Assistant',
        '.*alpaca': "Alpaca",
    },
    lora_prompts={
        'default': 'QA',
        '.*alpaca': "Alpaca",
    },
)
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-03-04 05:04:02 +01:00
def str2bool(v):
    """Convert a command-line style boolean value to ``bool``.

    Intended as an argparse ``type=`` callable. Booleans pass through
    unchanged; any other value is converted with ``str()``, stripped of
    surrounding whitespace and matched case-insensitively against the
    accepted spellings.

    Args:
        v: The raw value (normally the string argparse read from the CLI).

    Returns:
        True for yes/true/t/y/1, False for no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: if *v* is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v
    # str() routes non-string inputs (e.g. None, 0, 1) into the same
    # validation path instead of crashing with AttributeError on .lower().
    normalized = str(v).strip().lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
2023-04-07 05:15:45 +02:00
# Command-line interface. `args` holds the values parsed from sys.argv;
# `args_defaults` holds what the parser yields with no arguments at all, so
# other code can tell which options the user actually changed.
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))

# Basic settings
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.')
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, nargs="+", help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
parser.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')

# Accelerate/transformers
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maximum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
parser.add_argument('--trust-remote-code', action='store_true', help="Set trust_remote_code=True while loading a model. Necessary for ChatGLM.")

# llama.cpp
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
parser.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.')
parser.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.')
parser.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.')

# GPTQ
parser.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
parser.add_argument('--model_type', type=str, help='Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
parser.add_argument('--groupsize', type=int, default=-1, help='Group size.')
parser.add_argument('--pre_layer', type=int, default=0, help='The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models.')
parser.add_argument('--checkpoint', type=str, help='The path to the quantized checkpoint file. If not specified, it will be automatically detected.')
parser.add_argument('--monkey-patch', action='store_true', help='Apply the monkey patch for using LoRAs with quantized models.')
parser.add_argument('--quant_attn', action='store_true', help='(triton) Enable quant attention.')
parser.add_argument('--warmup_autotune', action='store_true', help='(triton) Enable warmup autotune.')
parser.add_argument('--fused_mlp', action='store_true', help='(triton) Enable fused mlp.')

# FlexGen
parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')
parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.")
# nargs="?" with const=True lets a bare `--pin-weight` mean True, while
# `--pin-weight false` is parsed through str2bool. "%%" is argparse's escape
# for a literal percent sign in help text.
parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, default=True, help="FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%%).")

# DeepSpeed
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')

# RWKV
parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')

# Gradio
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--listen-host', type=str, help='The hostname that the server will use.')
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)

# API
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')

args = parser.parse_args()
# Parsing an empty argument list yields every option's default value.
args_defaults = parser.parse_args([])
2023-03-14 11:56:31 +01:00
2023-04-05 16:49:59 +02:00
# Deprecation warnings for parameters that have been renamed.
# Each entry maps an old argument name to a sequence whose item [0] is the
# replacement argument name and whose item [1] is the value the old argument
# is compared against (presumably its parser default). When the old argument
# differs from that value, a warning is logged and the value is copied onto
# the new argument. The table is currently empty, so the loop does nothing.
deprecated_dict = {}
for k, spec in deprecated_dict.items():
    old_value = getattr(args, k)
    if old_value != spec[1]:
        logging.warning(f"--{k} is deprecated and will be removed. Use --{spec[0]} instead.")
        setattr(args, spec[0], old_value)
2023-04-02 01:14:43 +02:00
2023-04-17 00:15:03 +02:00
# Security warnings: log a warning at startup for each risky opt-in flag.
if args.trust_remote_code:
    logging.warning('trust_remote_code is enabled. This is dangerous.')

if args.share:
    logging.warning('The gradio "share link" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.')
2023-04-17 00:15:03 +02:00
2023-04-23 20:52:43 +02:00
# Activating the API extension: --api or --public-api makes sure 'api' is in
# the extension list, creating the list when no --extensions were given.
if args.api or args.public_api:
    if args.extensions is None:
        args.extensions = []
    if 'api' not in args.extensions:
        args.extensions.append('api')
2023-04-07 05:15:45 +02:00
2023-04-02 01:14:43 +02:00
def is_chat():
    """Report whether the web UI was launched in chat mode.

    Returns:
        The ``chat`` attribute of the module-level parsed ``args``.
    """
    chat_mode = args.chat
    return chat_mode
2023-04-14 16:07:28 +02:00
# Model-specific settings: the defaults come from <model_dir>/config.yaml, and
# user-defined overrides from <model_dir>/config-user.yaml are merged on top.


def _load_yaml_config(path):
    """Parse the YAML file at *path* and return its top-level value.

    Args:
        path: A ``pathlib.Path`` to the YAML file.

    Returns:
        The parsed document, or ``{}`` if the file does not exist or is empty
        (``yaml.safe_load`` returns ``None`` for an empty document, and the
        merge below needs a mapping).
    """
    if not path.exists():
        return {}
    # The context manager closes the handle promptly. YAML text is Unicode, so
    # decode explicitly as UTF-8 instead of using the platform default.
    with open(path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    return {} if data is None else data


# Loading model-specific settings (default)
model_config = _load_yaml_config(Path(args.model_dir) / 'config.yaml')

# Applying user-defined model settings: for a key that already exists, the
# user's entries are merged into it with dict.update; any other key is added
# as-is.
user_config = _load_yaml_config(Path(args.model_dir) / 'config-user.yaml')
for k in user_config:
    if k in model_config:
        model_config[k].update(user_config[k])
    else:
        model_config[k] = user_config[k]