2023-02-23 16:05:25 +01:00
import argparse
2023-05-12 11:09:45 +02:00
from collections import OrderedDict
2023-04-14 16:07:28 +02:00
from pathlib import Path
import yaml
2023-02-23 16:05:25 +01:00
2023-05-22 03:42:34 +02:00
from modules . logging_colors import logger
2023-05-24 14:38:20 +02:00
# Module-level shared state. Everything below is only initialised here; the
# code that later reassigns these names is not visible in this section.

# NOTE(review): presumably replaced by a real lock object elsewhere — confirm.
generation_lock = None
2023-02-23 16:05:25 +01:00
# Placeholders for the loaded model and its tokenizer (None at import time).
model = None
tokenizer = None
2023-03-17 01:31:39 +01:00
# Note that this is the *string* None, not the None object.
model_name = " None "
2023-04-22 19:56:48 +02:00
model_type = None
2023-04-14 19:52:06 +02:00
# Starts empty; presumably the names of the LoRAs currently applied — confirm.
lora_names = [ ]
2023-02-23 16:05:25 +01:00
soft_prompt_tensor = None
soft_prompt = False
2023-02-23 17:42:23 +01:00
2023-02-23 19:11:18 +01:00
# Chat variables
# history holds two parallel lists keyed 'internal' and 'visible'.
history = { ' internal ' : [ ] , ' visible ' : [ ] }
character = ' None '
2023-02-23 19:26:41 +01:00
# NOTE(review): looks like a cooperative "abort generation" flag — confirm against its readers.
stop_everything = False
2023-03-14 02:28:00 +01:00
processing_message = ' *Is typing...* '
2023-02-23 19:11:18 +01:00
2023-02-24 20:46:50 +01:00
# UI elements (buttons, sliders, HTML, etc)
gradio = { }
2023-04-24 08:05:47 +02:00
# For keeping the values of UI elements on page reload
persistent_interface_state = { }
2023-05-11 21:12:46 +02:00
input_params = [ ] # Generation input parameters
reload_inputs = [ ] # Parameters for reloading the chat interface
2023-02-25 04:23:51 +01:00
2023-03-16 03:29:56 +01:00
# For restarting the interface
need_restart = False
2023-02-23 17:42:23 +01:00
# Built-in default interface settings.
# The --settings flag's help text says these defaults can be loaded from a
# yaml file; the code that performs that override is not in this section.
# NOTE(review): the *_min / *_max entries look like bounds for the matching
# value (e.g. max_new_tokens) — confirm where they are consumed.
settings = {
2023-05-23 00:45:11 +02:00
' dark_theme ' : False ,
2023-05-09 20:52:35 +02:00
' autoload_model ' : True ,
2023-02-23 17:42:23 +01:00
' max_new_tokens ' : 200 ,
' max_new_tokens_min ' : 1 ,
' max_new_tokens_max ' : 2000 ,
2023-03-31 17:22:07 +02:00
' seed ' : - 1 ,
2023-04-24 18:19:42 +02:00
' character ' : ' None ' ,
2023-03-23 17:36:00 +01:00
' name1 ' : ' You ' ,
' name2 ' : ' Assistant ' ,
2023-05-16 00:39:08 +02:00
' context ' : ' This is a conversation with your Assistant. It is a computer program designed to help you with various tasks such as answering questions, providing recommendations, and helping with decision making. You can ask it anything you want and it will do its best to give you accurate and relevant information. ' ,
2023-04-12 23:30:43 +02:00
' greeting ' : ' ' ,
2023-04-26 08:21:53 +02:00
' turn_template ' : ' ' ,
2023-04-11 17:30:06 +02:00
' custom_stopping_strings ' : ' ' ,
2023-03-18 14:55:57 +01:00
' stop_at_newline ' : False ,
2023-04-10 21:44:22 +02:00
' add_bos_token ' : True ,
2023-04-11 23:46:06 +02:00
' ban_eos_token ' : False ,
2023-04-16 19:24:49 +02:00
' skip_special_tokens ' : True ,
2023-04-11 23:46:06 +02:00
' truncation_length ' : 2048 ,
' truncation_length_min ' : 0 ,
2023-04-19 22:35:38 +02:00
' truncation_length_max ' : 8192 ,
2023-05-08 17:35:03 +02:00
' mode ' : ' chat ' ,
' chat_style ' : ' cai-chat ' ,
2023-04-14 16:07:28 +02:00
' instruction_template ' : ' None ' ,
2023-05-14 15:43:55 +02:00
# The command below embeds <|character|> and <|prompt|> placeholders;
# presumably substituted by the chat-instruct code path — confirm.
' chat-instruct_command ' : ' Continue the chat dialogue below. Write a single reply for the character " <|character|> " . \n \n <|prompt|> ' ,
2023-02-23 17:42:23 +01:00
' chat_prompt_size ' : 2048 ,
' chat_prompt_size_min ' : 0 ,
' chat_prompt_size_max ' : 2048 ,
2023-02-25 05:42:19 +01:00
' chat_generation_attempts ' : 1 ,
' chat_generation_attempts_min ' : 1 ,
2023-05-12 17:50:29 +02:00
' chat_generation_attempts_max ' : 10 ,
2023-02-28 06:20:11 +01:00
' default_extensions ' : [ ] ,
2023-05-29 03:34:12 +02:00
' chat_default_extensions ' : [ ' gallery ' ] ,
' preset ' : ' LLaMA-Precise ' ,
' prompt ' : ' QA ' ,
2023-02-23 17:42:23 +01:00
}
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-03-04 05:04:02 +01:00
def str2bool(v):
    """Convert a command-line value to a bool (usable as an argparse ``type=``).

    Args:
        v: Either a bool, which is returned unchanged, or a string spelling a
            boolean. Matching is case-insensitive and ignores surrounding
            whitespace.

    Returns:
        True for ``yes/true/t/y/1`` and False for ``no/false/f/n/0``.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a bool and is not one of
            the recognised spellings (including non-string values).
    """
    if isinstance(v, bool):
        return v

    # Non-string input (None, ints, ...) has no .lower(); report it as a bad
    # argument instead of leaking an AttributeError to the caller.
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected.')

    # Normalise once rather than lower-casing for every comparison.
    token = v.strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
2023-04-07 05:15:45 +02:00
# Command-line parser for the web UI. The custom formatter_class only raises
# max_help_position to 54 so the help column starts further to the right.
parser = argparse . ArgumentParser ( formatter_class = lambda prog : argparse . HelpFormatter ( prog , max_help_position = 54 ) )
2023-04-01 02:18:05 +02:00
# Basic settings
2023-02-23 16:05:25 +01:00
parser . add_argument ( ' --notebook ' , action = ' store_true ' , help = ' Launch the web UI in notebook mode, where the output is written to the same text box as the input. ' )
2023-04-05 16:49:59 +02:00
parser . add_argument ( ' --chat ' , action = ' store_true ' , help = ' Launch the web UI in chat mode with a style similar to the Character.AI website. ' )
2023-04-24 18:19:42 +02:00
parser . add_argument ( ' --character ' , type = str , help = ' The name of the character to load in chat mode by default. ' )
2023-04-01 02:18:05 +02:00
parser . add_argument ( ' --model ' , type = str , help = ' Name of the model to load by default. ' )
2023-04-26 03:58:48 +02:00
# nargs="+" makes --lora (and --extensions below) a list of one or more names.
parser . add_argument ( ' --lora ' , type = str , nargs = " + " , help = ' The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces. ' )
2023-04-01 02:18:05 +02:00
parser . add_argument ( " --model-dir " , type = str , default = ' models/ ' , help = " Path to directory with all the models " )
parser . add_argument ( " --lora-dir " , type = str , default = ' loras/ ' , help = " Path to directory with all the loras " )
2023-04-13 02:24:26 +02:00
parser . add_argument ( ' --model-menu ' , action = ' store_true ' , help = ' Show a model menu in the terminal when the web UI is first launched. ' )
2023-04-01 02:18:05 +02:00
parser . add_argument ( ' --no-stream ' , action = ' store_true ' , help = ' Don \' t stream the text output in real time. ' )
2023-05-29 03:34:12 +02:00
parser . add_argument ( ' --settings ' , type = str , help = ' Load the default interface settings from this yaml file. See settings-template.yaml for an example. If you create a file called settings.yaml, this file will be loaded by default without the need to use the --settings flag. ' )
2023-04-01 02:18:05 +02:00
parser . add_argument ( ' --extensions ' , type = str , nargs = " + " , help = ' The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. ' )
parser . add_argument ( ' --verbose ' , action = ' store_true ' , help = ' Print the prompts to the terminal. ' )
# Accelerate/transformers
2023-04-10 22:29:00 +02:00
parser . add_argument ( ' --cpu ' , action = ' store_true ' , help = ' Use the CPU to generate text. Warning: Training on CPU is extremely slow. ' )
2023-04-01 02:18:05 +02:00
parser . add_argument ( ' --auto-devices ' , action = ' store_true ' , help = ' Automatically split the model across the available GPU(s) and CPU. ' )
2023-06-01 17:08:39 +02:00
# --gpu-memory takes one value per GPU (nargs="+"); values stay strings so
# that unit suffixes such as "3500MiB" are accepted, as its help text describes.
parser . add_argument ( ' --gpu-memory ' , type = str , nargs = " + " , help = ' Maximum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB. ' )
2023-04-01 18:56:47 +02:00
parser . add_argument ( ' --cpu-memory ' , type = str , help = ' Maximum CPU memory in GiB to allocate for offloaded weights. Same as above. ' )
2023-04-01 02:18:05 +02:00
parser . add_argument ( ' --disk ' , action = ' store_true ' , help = ' If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. ' )
parser . add_argument ( ' --disk-cache-dir ' , type = str , default = " cache " , help = ' Directory to save the disk cache to. Defaults to " cache " . ' )
2023-05-25 06:14:13 +02:00
parser . add_argument ( ' --load-in-8bit ' , action = ' store_true ' , help = ' Load the model with 8-bit precision (using bitsandbytes). ' )
2023-04-01 02:18:05 +02:00
parser . add_argument ( ' --bf16 ' , action = ' store_true ' , help = ' Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. ' )
parser . add_argument ( ' --no-cache ' , action = ' store_true ' , help = ' Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost. ' )
2023-04-10 04:08:40 +02:00
parser . add_argument ( ' --xformers ' , action = ' store_true ' , help = " Use xformer ' s memory efficient attention. This should increase your tokens/s. " )
parser . add_argument ( ' --sdp-attention ' , action = ' store_true ' , help = " Use torch 2.0 ' s sdp attention. " )
2023-05-29 15:20:18 +02:00
# Enabling this flag triggers a security warning later in this module.
parser . add_argument ( ' --trust-remote-code ' , action = ' store_true ' , help = " Set trust_remote_code=True while loading a model. Necessary for ChatGLM and Falcon. " )
2023-03-26 05:11:33 +02:00
2023-05-25 06:14:13 +02:00
# Accelerate 4-bit
# NOTE(review): the three options after --load-in-4bit use underscores in
# their flag names, unlike the hyphenated flags above.
parser . add_argument ( ' --load-in-4bit ' , action = ' store_true ' , help = ' Load the model with 4-bit precision (using bitsandbytes). ' )
2023-05-25 20:05:53 +02:00
parser . add_argument ( ' --compute_dtype ' , type = str , default = " float16 " , help = " compute dtype for 4-bit. Valid options: bfloat16, float16, float32. " )
2023-05-25 06:14:13 +02:00
parser . add_argument ( ' --quant_type ' , type = str , default = " nf4 " , help = ' quant_type for 4-bit. Valid options: nf4, fp4. ' )
parser . add_argument ( ' --use_double_quant ' , action = ' store_true ' , help = ' use_double_quant for 4-bit. ' )
2023-04-01 02:18:05 +02:00
# llama.cpp
2023-05-02 23:25:28 +02:00
parser . add_argument ( ' --threads ' , type = int , default = 0 , help = ' Number of threads to use. ' )
parser . add_argument ( ' --n_batch ' , type = int , default = 512 , help = ' Maximum number of prompt tokens to batch together when calling llama_eval. ' )
2023-05-03 14:50:31 +02:00
parser . add_argument ( ' --no-mmap ' , action = ' store_true ' , help = ' Prevent mmap from being used. ' )
2023-05-02 23:25:28 +02:00
parser . add_argument ( ' --mlock ' , action = ' store_true ' , help = ' Force the system to keep the model in RAM. ' )
2023-05-16 01:19:55 +02:00
# Kept as a string so unit suffixes (MiB/GiB) can be given, per the help text.
parser . add_argument ( ' --cache-capacity ' , type = str , help = ' Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed. ' )
2023-05-15 03:58:11 +02:00
parser . add_argument ( ' --n-gpu-layers ' , type = int , default = 0 , help = ' Number of layers to offload to the GPU. ' )
2023-05-25 15:29:31 +02:00
parser . add_argument ( ' --n_ctx ' , type = int , default = 2048 , help = ' Size of the prompt context. ' )
parser . add_argument ( ' --llama_cpp_seed ' , type = int , default = 0 , help = ' Seed for llama-cpp models. Default 0 (random) ' )
2023-04-01 02:18:05 +02:00
# GPTQ
2023-04-17 15:55:35 +02:00
parser . add_argument ( ' --wbits ' , type = int , default = 0 , help = ' Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. ' )
parser . add_argument ( ' --model_type ' , type = str , help = ' Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. ' )
parser . add_argument ( ' --groupsize ' , type = int , default = - 1 , help = ' Group size. ' )
2023-05-17 15:41:09 +02:00
# --pre_layer parses to a list of ints (nargs="+"), one entry per GPU per its help text.
parser . add_argument ( ' --pre_layer ' , type = int , nargs = " + " , help = ' The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. For multi-gpu, write the numbers separated by spaces, eg --pre_layer 30 60. ' )
2023-05-04 20:17:20 +02:00
parser . add_argument ( ' --checkpoint ' , type = str , help = ' The path to the quantized checkpoint file. If not specified, it will be automatically detected. ' )
2023-04-17 15:55:35 +02:00
parser . add_argument ( ' --monkey-patch ' , action = ' store_true ' , help = ' Apply the monkey patch for using LoRAs with quantized models. ' )
2023-04-22 17:27:30 +02:00
parser . add_argument ( ' --quant_attn ' , action = ' store_true ' , help = ' (triton) Enable quant attention. ' )
parser . add_argument ( ' --warmup_autotune ' , action = ' store_true ' , help = ' (triton) Enable warmup autotune. ' )
parser . add_argument ( ' --fused_mlp ' , action = ' store_true ' , help = ' (triton) Enable fused mlp. ' )
2023-03-26 05:11:33 +02:00
2023-05-17 16:12:12 +02:00
# AutoGPTQ
parser . add_argument ( ' --autogptq ' , action = ' store_true ' , help = ' Use AutoGPTQ for loading quantized models instead of the internal GPTQ loader. ' )
parser . add_argument ( ' --triton ' , action = ' store_true ' , help = ' Use triton. ' )
2023-04-01 02:18:05 +02:00
# FlexGen
2023-02-23 16:05:25 +01:00
parser . add_argument ( ' --flexgen ' , action = ' store_true ' , help = ' Enable the use of FlexGen offloading. ' )
2023-02-24 12:55:09 +01:00
# nargs="+" does not enforce the "6 numbers" stated in the help text; any
# count >= 1 is accepted by the parser itself.
parser . add_argument ( ' --percent ' , type = int , nargs = " + " , default = [ 0 , 100 , 100 , 0 , 100 , 0 ] , help = ' FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). ' )
2023-02-23 16:05:25 +01:00
parser . add_argument ( " --compress-weight " , action = " store_true " , help = " FlexGen: activate weight compression. " )
2023-03-04 05:04:02 +01:00
# Optional-value boolean: a bare --pin-weight yields const=True, an explicit
# value goes through str2bool, and omitting the flag also yields True.
# "%%" is an escaped percent sign, required because argparse %-formats help strings.
parser . add_argument ( " --pin-weight " , type = str2bool , nargs = " ? " , const = True , default = True , help = " FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20 %% ). " )
2023-04-01 02:18:05 +02:00
# DeepSpeed
2023-02-23 16:05:25 +01:00
parser . add_argument ( ' --deepspeed ' , action = ' store_true ' , help = ' Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. ' )
parser . add_argument ( ' --nvme-offload-dir ' , type = str , help = ' DeepSpeed: Directory to use for ZeRO-3 NVME offloading. ' )
parser . add_argument ( ' --local_rank ' , type = int , default = 0 , help = ' DeepSpeed: Optional argument for distributed setups. ' )
2023-04-01 02:18:05 +02:00
# RWKV
2023-03-07 00:12:54 +01:00
parser . add_argument ( ' --rwkv-strategy ' , type = str , default = None , help = ' RWKV: The strategy to use while loading the model. Examples: " cpu fp32 " , " cuda fp16 " , " cuda fp16i8 " . ' )
parser . add_argument ( ' --rwkv-cuda-on ' , action = ' store_true ' , help = ' RWKV: Compile the CUDA kernel for better performance. ' )
2023-04-01 02:18:05 +02:00
# Gradio
2023-02-23 16:05:25 +01:00
parser . add_argument ( ' --listen ' , action = ' store_true ' , help = ' Make the web UI reachable from your local network. ' )
2023-04-14 02:35:08 +02:00
parser . add_argument ( ' --listen-host ' , type = str , help = ' The hostname that the server will use. ' )
2023-02-23 16:05:25 +01:00
parser . add_argument ( ' --listen-port ' , type = int , help = ' The listening port that the server will use. ' )
# Enabling --share triggers a security warning later in this module.
parser . add_argument ( ' --share ' , action = ' store_true ' , help = ' Create a public URL. This is useful for running the web UI on Google Colab or similar. ' )
2023-03-13 16:44:18 +01:00
# default=False is redundant with action='store_true' (that is already its default).
parser . add_argument ( ' --auto-launch ' , action = ' store_true ' , default = False , help = ' Open the web UI in the default browser upon launch. ' )
2023-05-24 01:39:26 +02:00
parser . add_argument ( " --gradio-auth " , type = str , help = ' set gradio authentication like " username:password " ; or comma-delimit multiple like " u1:p1,u2:p2,u3:p3 " ' , default = None )
2023-03-28 04:39:26 +02:00
parser . add_argument ( " --gradio-auth-path " , type = str , help = ' Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: " u1:p1,u2:p2,u3:p3 " ' , default = None )
2023-04-01 02:18:05 +02:00
2023-04-23 20:52:43 +02:00
# API
# Either --api or --public-api causes the 'api' extension to be activated
# further down in this module.
parser . add_argument ( ' --api ' , action = ' store_true ' , help = ' Enable the API extension. ' )
2023-05-16 01:44:16 +02:00
parser . add_argument ( ' --api-blocking-port ' , type = int , default = 5000 , help = ' The listening port for the blocking API. ' )
parser . add_argument ( ' --api-streaming-port ' , type = int , default = 5005 , help = ' The listening port for the streaming API. ' )
2023-04-23 20:52:43 +02:00
# NOTE(review): "Cloudfare" in the help text below looks like a typo for "Cloudflare".
parser . add_argument ( ' --public-api ' , action = ' store_true ' , help = ' Create a public URL for the API using Cloudfare. ' )
2023-05-10 01:18:02 +02:00
# Multimodal
# A non-None value causes the 'multimodal' extension to be activated further down.
parser . add_argument ( ' --multimodal-pipeline ' , type = str , default = None , help = ' The multimodal pipeline to use. Examples: llava-7b, llava-13b. ' )
2023-04-23 20:52:43 +02:00
2023-02-23 16:05:25 +01:00
# Parse the real command line. This is a top-level statement, so it runs as
# soon as the module is executed/imported.
args = parser . parse_args ( )
2023-04-14 20:35:06 +02:00
# Parsing an empty argument list yields a namespace holding every option's
# declared default, independent of what the user actually passed.
args_defaults = parser . parse_args ( [ ] )
2023-03-14 11:56:31 +01:00
2023-04-05 16:49:59 +02:00
# Deprecation warnings for parameters that have been renamed
2023-04-12 22:09:56 +02:00
# Maps a deprecated option name to an indexable pair: [0] is the replacement
# option's attribute name, [1] is the value meaning "not set by the user".
# The table is currently empty, so the loop below does nothing.
deprecated_dict = { }
2023-03-26 05:11:33 +02:00
for k in deprecated_dict :
2023-04-16 06:36:50 +02:00
# Only act when the deprecated option differs from its "unset" value.
if getattr ( args , k ) != deprecated_dict [ k ] [ 1 ] :
2023-05-22 03:42:34 +02:00
logger . warning ( f " -- { k } is deprecated and will be removed. Use -- { deprecated_dict [ k ] [ 0 ] } instead. " )
2023-04-16 06:36:50 +02:00
# Forward the value to the replacement option so the old flag keeps working.
setattr ( args , deprecated_dict [ k ] [ 0 ] , getattr ( args , k ) )
2023-04-02 01:14:43 +02:00
2023-04-17 00:15:03 +02:00
# Security warnings
# These only log; neither option is disabled or altered here.
if args . trust_remote_code :
2023-05-22 03:42:34 +02:00
logger . warning ( " trust_remote_code is enabled. This is dangerous. " )
2023-04-18 00:34:28 +02:00
if args . share :
2023-05-29 03:48:20 +02:00
logger . warning ( " The gradio \" share link \" feature uses a proprietary executable to create a reverse tunnel. Use it with care. " )
2023-04-17 00:15:03 +02:00
2023-05-10 01:18:02 +02:00
def add_extension(name):
    """Make sure extension ``name`` is listed in ``args.extensions``.

    Args:
        name: Name of the extension to activate.

    Side effects:
        Mutates the module-level ``args`` namespace: creates the list when no
        ``--extensions`` were given (the attribute is then None), otherwise
        appends ``name`` unless it is already present.
    """
    if args.extensions is None:
        args.extensions = [name]
    # Test for the requested name itself. Checking a fixed extension name here
    # would skip other extensions whenever that one is already listed, and
    # would append duplicates when it is not.
    elif name not in args.extensions:
        args.extensions.append(name)
# Activating the API extension
# --public-api alone is enough: it implies the API extension as well.
if args . api or args . public_api :
add_extension ( ' api ' )
# Activating the multimodal extension
# Triggered by any explicit --multimodal-pipeline value (its default is None).
if args . multimodal_pipeline is not None :
add_extension ( ' multimodal ' )
2023-04-23 20:52:43 +02:00
2023-04-07 05:15:45 +02:00
2023-04-02 01:14:43 +02:00
def is_chat():
    """Return the value of the ``--chat`` command-line flag."""
    chat_enabled = args.chat
    return chat_enabled
2023-04-14 16:07:28 +02:00
2023-05-12 11:09:45 +02:00
# Loading model-specific settings
# Read <model_dir>/config.yaml when it exists. Path objects are used directly:
# using a Path as a context manager is deprecated and was a no-op anyway.
model_config = {}
p = Path(f'{args.model_dir}/config.yaml')
if p.exists():
    # Close the file deterministically; an empty YAML document parses to None,
    # so fall back to an empty mapping.
    with open(p, 'r') as f:
        model_config = yaml.safe_load(f) or {}

# Applying user-defined model settings
# Entries from <model_dir>/config-user.yaml are merged on top: keys already
# present get their mapping updated in place, new keys are added as-is.
p = Path(f'{args.model_dir}/config-user.yaml')
if p.exists():
    with open(p, 'r') as f:
        user_config = yaml.safe_load(f) or {}

    for k in user_config:
        if k in model_config:
            model_config[k].update(user_config[k])
        else:
            model_config[k] = user_config[k]

# Exposed as an OrderedDict, exactly as before.
model_config = OrderedDict(model_config)