2024-02-25 09:24:52 -08:00
import importlib
2023-03-24 16:51:27 -03:00
import traceback
2023-04-24 01:32:22 +02:00
from functools import partial
2023-07-25 18:49:56 -03:00
from inspect import signature
2023-03-24 16:51:27 -03:00
2023-03-15 15:11:16 -03:00
import gradio as gr
2023-02-23 12:05:25 -03:00
import extensions
2023-02-23 14:41:42 -03:00
import modules . shared as shared
2023-05-21 22:42:34 -03:00
from modules . logging_colors import logger
2023-02-23 12:05:25 -03:00
2023-02-24 10:01:21 -03:00
# Bookkeeping for loaded extensions: name -> [enabled, position of the name in
# shared.args.extensions]. load_extensions() rebinds this to a fresh dict.
state = dict()

# Extension names that load_extensions() is allowed to import. This module only
# reads it; presumably the caller fills it before loading — TODO confirm.
available_extensions = list()

# Script modules whose settings/setup() step has already run, so a repeated
# call to load_extensions() does not run it a second time.
setup_called = set()
2023-02-23 12:05:25 -03:00
2023-04-07 00:15:45 -03:00
2023-04-24 23:23:11 -04:00
def apply_settings(extension, name):
    """Register an extension's default params and apply user overrides.

    For every key in ``extension.params`` the current value is stored in
    ``shared.default_settings`` under ``"<name>-<key>"``. If
    ``shared.settings`` has an entry under that same id, its value replaces the
    param in place. Extensions without a ``params`` attribute are left alone.

    Args:
        extension: the imported extension script module.
        name: the extension's name, used as the settings-id prefix.
    """
    if hasattr(extension, 'params'):
        params = extension.params
        for key in params:
            setting_id = f"{name}-{key}"
            shared.default_settings[setting_id] = params[key]
            if setting_id in shared.settings:
                params[key] = shared.settings[setting_id]
2023-04-24 23:23:11 -04:00
2023-02-23 14:49:02 -03:00
def load_extensions():
    """Import, set up and register every extension requested on the command line.

    For each name in ``shared.args.extensions`` that also appears in
    ``available_extensions``, ``extensions.<name>.script`` is imported. The
    first time a given script module is seen, ``apply_settings`` is run on it
    and its optional ``setup()`` hook is called. Successfully loaded
    extensions are recorded in the module-level ``state`` as
    ``[True, position]``, where position is the index in
    ``shared.args.extensions``; ``iterator()`` uses it for ordering.

    A failure while loading one extension is logged with a traceback and
    does not prevent the remaining extensions from loading.
    """
    # Only ``state`` is rebound here; ``setup_called`` is merely mutated and
    # therefore needs no ``global`` declaration.
    global state
    state = {}

    for i, name in enumerate(shared.args.extensions):
        if name not in available_extensions:
            continue

        # The 'api' extension is loaded without the informational log line.
        if name != 'api':
            logger.info(f'Loading the extension "{name}"')

        try:
            try:
                extension = importlib.import_module(f"extensions.{name}.script")
            except ModuleNotFoundError:
                logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n* To install requirements for all available extensions, launch the\n  update_wizard script for your OS and choose the B option.\n\n* To install the requirements for this extension alone, launch the\n  cmd script for your OS and paste the following command in the\n  terminal window that appears:\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n")
                raise

            # Only run setup() and apply settings from settings.yaml once
            if extension not in setup_called:
                apply_settings(extension, name)
                if hasattr(extension, "setup"):
                    extension.setup()

                setup_called.add(extension)

            state[name] = [True, i]
        except Exception:
            # ``Exception`` rather than a bare ``except:`` so Ctrl+C
            # (KeyboardInterrupt) and sys.exit() (SystemExit) raised during an
            # extension's import or setup() still propagate instead of being
            # reported as an ordinary load failure.
            logger.error(f'Failed to load the extension "{name}".')
            traceback.print_exc()
2023-02-23 14:49:02 -03:00
2023-04-07 00:15:45 -03:00
2023-04-07 12:20:57 -03:00
def iterator():
    """Yield ``(script_module, name)`` for every enabled extension.

    Extensions come out in the order specified on the command line, i.e.
    sorted by the position stored in ``state[name][1]``. Entries whose enabled
    flag (``state[name][0]``) is falsy are skipped. The script module is the
    ``script`` attribute of the ``extensions.<name>`` package.
    """
    def load_position(ext_name):
        return state[ext_name][1]

    for ext_name in sorted(state, key=load_position):
        if not state[ext_name][0]:
            continue

        package = getattr(extensions, ext_name)
        yield package.script, ext_name
2023-02-24 10:01:21 -03:00
2023-04-07 00:15:45 -03:00
2023-04-07 12:20:57 -03:00
def _apply_string_extensions(function_name, text, state, is_chat=False):
    """Pipe ``text`` through the ``function_name`` hook of each enabled extension.

    Hooks run in ``iterator()`` order; each one receives the previous hook's
    output. Each hook's signature is inspected so older extensions keep
    working:

    * a hook with exactly two parameters besides ``is_chat`` is called as
      ``hook(text, state)``; any other count gets ``hook(text)``;
    * ``is_chat`` is passed as a keyword only when the hook declares a
      parameter with that name.

    Returns the text after all hooks ran (unchanged if no extension has one).
    """
    for extension, _ in iterator():
        if not hasattr(extension, function_name):
            continue

        hook = getattr(extension, function_name)

        # Older extensions may lack the 'state' argument and/or the 'is_chat'
        # keyword, so decide what to pass from the declared parameter names.
        param_names = list(signature(hook).parameters)
        accepts_is_chat = 'is_chat' in param_names
        other_param_count = len(param_names) - (1 if accepts_is_chat else 0)

        positional = [text, state] if other_param_count == 2 else [text]
        keyword = {'is_chat': is_chat} if accepts_is_chat else {}
        text = hook(*positional, **keyword)

    return text
2023-04-07 00:15:45 -03:00
2023-07-25 18:49:56 -03:00
def _apply_chat_input_extensions(text, visible_text, state):
    """Pass the ``(text, visible_text)`` pair through each ``chat_input_modifier``.

    Every enabled extension that defines
    ``chat_input_modifier(text, visible_text, state)`` is called in
    ``iterator()`` order. The pair it returns is handed to the next extension.

    Returns:
        The final ``(text, visible_text)`` tuple.
    """
    for extension, _ in iterator():
        if not hasattr(extension, 'chat_input_modifier'):
            continue

        modified = extension.chat_input_modifier(text, visible_text, state)
        text, visible_text = modified

    return text, visible_text
2023-05-10 01:18:02 +02:00
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
    """Delegate chat-prompt construction to an extension, if one provides it.

    Only the first enabled extension (in ``iterator()`` order) that defines
    ``custom_generate_chat_prompt`` is used; any later ones are ignored.

    Returns:
        That hook's result for ``(text, state, **kwargs)``, or ``None`` when no
        enabled extension defines the hook.
    """
    for ext, _ in iterator():
        if hasattr(ext, 'custom_generate_chat_prompt'):
            prompt_builder = ext.custom_generate_chat_prompt
            return prompt_builder(text, state, **kwargs)

    return None
2023-05-05 18:53:03 -03:00
def _apply_state_modifier_extensions(state):
    """Thread ``state`` through every enabled extension's ``state_modifier``.

    Each hook receives the value returned by the previous one, so an
    extension can change the input parameters before they are used.
    """
    for ext, _ in iterator():
        if not hasattr(ext, "state_modifier"):
            continue

        modifier = getattr(ext, "state_modifier")
        state = modifier(state)

    return state
2023-05-09 22:49:39 -03:00
2023-05-05 18:53:03 -03:00
2023-05-21 13:24:54 -03:00
def _apply_history_modifier_extensions(history):
    """Thread the chat ``history`` through every ``history_modifier`` hook.

    Hooks run in ``iterator()`` order, each receiving the previous result,
    and the final history is returned.
    """
    for ext, _ in iterator():
        if not hasattr(ext, "history_modifier"):
            continue

        modifier = getattr(ext, "history_modifier")
        history = modifier(history)

    return history
2023-07-13 13:22:41 -07:00
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
    """Let extensions override the tokenizer output.

    Every enabled extension defining ``function_name`` is called as
    ``hook(state, prompt, input_ids, input_embeds)`` in ``iterator()`` order.
    The returned triple replaces the current one before the next hook runs.

    Returns:
        The final ``(prompt, input_ids, input_embeds)`` tuple.
    """
    for ext, _ in iterator():
        if hasattr(ext, function_name):
            hook = getattr(ext, function_name)
            outputs = hook(state, prompt, input_ids, input_embeds)
            prompt, input_ids, input_embeds = outputs

    return prompt, input_ids, input_embeds
2023-07-13 13:22:41 -07:00
def _apply_logits_processor_extensions(function_name, processor_list, input_ids):
    """Let extensions add their own logits processors to the stack being run.

    Every enabled extension defining ``function_name`` is called as
    ``hook(processor_list, input_ids)`` in ``iterator()`` order. A hook may
    either mutate ``processor_list`` in place, e.g.
    ``processor_list.append(TheirLogitsProcessor())``, or return a
    replacement list. Any non-list return value (such as ``None``) keeps the
    current list.

    Returns:
        The processor list after all hooks ran.
    """
    for extension, _ in iterator():
        if hasattr(extension, function_name):
            result = getattr(extension, function_name)(processor_list, input_ids)
            # isinstance rather than ``type(result) is list``: list subclasses
            # (e.g. transformers' LogitsProcessorList) are valid replacements
            # and were previously ignored silently.
            if isinstance(result, list):
                processor_list = result

    return processor_list
2023-07-13 13:22:41 -07:00
2023-05-10 01:18:02 +02:00
def _apply_custom_tokenized_length(prompt):
    """Return an extension-defined token count for ``prompt``, if any.

    Intended for extensions that override the default tokenizer output. Only
    the first enabled extension (in ``iterator()`` order) that defines
    ``custom_tokenized_length`` is consulted.

    Returns:
        That hook's result, or ``None`` when no enabled extension defines it.
    """
    for ext, _ in iterator():
        if not hasattr(ext, 'custom_tokenized_length'):
            continue

        length_fn = getattr(ext, 'custom_tokenized_length')
        return length_fn(prompt)

    return None
def _apply_custom_generate_reply():
    """Find an extension-provided replacement for reply generation.

    Only the first enabled extension (in ``iterator()`` order) defining
    ``custom_generate_reply`` is used. The function itself is returned, not
    called.

    Returns:
        The ``custom_generate_reply`` callable, or ``None`` if no enabled
        extension defines one.
    """
    for ext, _ in iterator():
        if not hasattr(ext, 'custom_generate_reply'):
            continue

        return getattr(ext, 'custom_generate_reply')

    return None
2023-05-17 00:03:39 -03:00
def _collect_extension_strings(hook_name):
    """Concatenate the strings returned by each enabled extension's ``hook_name()``.

    Extensions are visited in ``iterator()`` order; those that do not define
    the hook are skipped. Returns ``''`` when no extension defines it.
    """
    # str.join instead of repeated ``+=``: one pass, no intermediate copies.
    return ''.join(
        getattr(extension, hook_name)()
        for extension, _ in iterator()
        if hasattr(extension, hook_name)
    )


def _apply_custom_css():
    """Return the concatenated ``custom_css()`` output of all enabled extensions."""
    return _collect_extension_strings('custom_css')


def _apply_custom_js():
    """Return the concatenated ``custom_js()`` output of all enabled extensions."""
    return _collect_extension_strings('custom_js')
2023-05-17 01:25:01 -03:00
def create_extensions_block():
    """Render the UI of every non-tab extension inside one shared Gradio column.

    An enabled extension is included when it defines ``ui`` and does not set
    a truthy ``params['is_tab']``. Tab extensions are handled by
    ``create_extensions_tabs``. The column (``elem_id="extensions"``) is only
    created when at least one extension qualifies.
    """
    inline_extensions = []
    for ext, _ in iterator():
        if not hasattr(ext, "ui"):
            continue

        is_tab = hasattr(ext, 'params') and ext.params.get('is_tab', False)
        if not is_tab:
            inline_extensions.append(ext)

    if not inline_extensions:
        return

    # Creating the extension ui elements
    with gr.Column(elem_id="extensions"):
        for ext in inline_extensions:
            ext.ui()
def create_extensions_tabs():
    """Render each tab-style extension's UI in its own Gradio tab.

    An enabled extension gets a tab when it defines ``ui`` and sets a truthy
    ``params['is_tab']``. The tab label is ``params['display_name']`` and
    falls back to the extension name.
    """
    for ext, ext_name in iterator():
        if not hasattr(ext, "ui"):
            continue
        if not hasattr(ext, 'params'):
            continue
        if not ext.params.get('is_tab', False):
            continue

        label = ext.params.get('display_name', ext_name)
        with gr.Tab(label, elem_classes="extension-tab"):
            ext.ui()
2023-04-24 01:32:22 +02:00
# Dispatch table for apply_extensions(): hook category -> function that runs
# that category across all enabled extensions. Key order is preserved.
EXTENSION_MAP = dict(
    # string -> string hooks
    input=partial(_apply_string_extensions, "input_modifier"),
    output=partial(_apply_string_extensions, "output_modifier"),
    # (text, visible_text) pair hook
    chat_input=_apply_chat_input_extensions,
    # value-threading hooks
    state=_apply_state_modifier_extensions,
    history=_apply_history_modifier_extensions,
    bot_prefix=partial(_apply_string_extensions, "bot_prefix_modifier"),
    tokenizer=partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
    logits_processor=partial(_apply_logits_processor_extensions, "logits_processor_modifier"),
    # first-match overrides
    custom_generate_chat_prompt=_apply_custom_generate_chat_prompt,
    custom_generate_reply=_apply_custom_generate_reply,
    tokenized_length=_apply_custom_tokenized_length,
    # concatenated UI assets
    css=_apply_custom_css,
    js=_apply_custom_js,
)
def apply_extensions(typ, *args, **kwargs):
    """Run the extension hook category ``typ`` and return its result.

    ``typ`` selects an entry of ``EXTENSION_MAP``; the remaining arguments are
    forwarded to it unchanged.

    Raises:
        ValueError: if ``typ`` is not a key of ``EXTENSION_MAP``.
    """
    # Every value in EXTENSION_MAP is a callable, so None means "missing".
    handler = EXTENSION_MAP.get(typ)
    if handler is None:
        raise ValueError(f"Invalid extension type {typ}")

    return handler(*args, **kwargs)