2024-02-25 18:24:52 +01:00
import importlib
2023-03-24 20:51:27 +01:00
import traceback
2023-04-24 01:32:22 +02:00
from functools import partial
2023-07-25 23:49:56 +02:00
from inspect import signature
2023-03-24 20:51:27 +01:00
2023-03-15 19:11:16 +01:00
import gradio as gr
2023-02-23 16:05:25 +01:00
import extensions
2023-02-23 18:41:42 +01:00
import modules . shared as shared
2023-05-22 03:42:34 +02:00
from modules . logging_colors import logger
2023-02-23 16:05:25 +01:00
2023-02-24 14:01:21 +01:00
# Extension name -> [enabled flag, position in shared.args.extensions].
# Rebuilt from scratch by load_extensions() and read by iterator().
state = dict()
# Names that load_extensions() is allowed to load. Not filled in this part of
# the file; presumably populated by the caller before loading (verify).
available_extensions = list()
# Extension modules whose settings/setup() have already been applied, so a
# later load_extensions() call does not run them a second time.
setup_called = set()
2023-02-23 16:05:25 +01:00
2023-04-07 05:15:45 +02:00
2023-04-25 05:23:11 +02:00
def apply_settings(extension, name):
    """Merge saved settings into an extension's ``params``.

    For every key in ``extension.params``, the current value is recorded in
    ``shared.default_settings`` under ``"<name>-<key>"``. If ``shared.settings``
    holds a value for that key, it overwrites the extension's param.
    Extensions without a ``params`` attribute are left untouched.
    """
    if hasattr(extension, 'params'):
        params = extension.params
        for key in params:
            setting_id = f"{name}-{key}"
            # Remember the extension's own value as the default before overriding it.
            shared.default_settings[setting_id] = params[key]
            if setting_id in shared.settings:
                params[key] = shared.settings[setting_id]
2023-04-25 05:23:11 +02:00
2023-02-23 18:49:02 +01:00
def load_extensions():
    """Import and initialise every extension requested on the command line.

    Rebuilds the module-level ``state`` dict from scratch. For each name in
    ``shared.args.extensions`` that also appears in ``available_extensions``:
    imports ``extensions.<name>.script``. The first time a module is seen, this
    applies its saved settings and calls its optional ``setup()`` hook. The
    extension is then recorded as ``state[name] = [True, i]``, where ``i`` is
    its command-line position (``iterator()`` orders by it).

    A failure while loading one extension is logged with its traceback and
    does not stop the remaining extensions from loading.
    """
    # Only ``state`` is rebound here; ``setup_called`` is merely mutated.
    global state

    state = {}
    for i, name in enumerate(shared.args.extensions):
        if name not in available_extensions:
            continue

        if name != 'api':
            logger.info(f'Loading the extension "{name}"')

        try:
            try:
                extension = importlib.import_module(f"extensions.{name}.script")
            except ModuleNotFoundError:
                logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n\nIf you used the one-click installer, paste the command above in the terminal window opened after launching the cmd script for your OS.")
                raise

            # Only run setup() and apply settings from settings.yaml once
            if extension not in setup_called:
                apply_settings(extension, name)
                if hasattr(extension, "setup"):
                    extension.setup()

                setup_called.add(extension)

            state[name] = [True, i]
        # `Exception` rather than a bare `except:` so that KeyboardInterrupt /
        # SystemExit raised while an extension loads still abort the program.
        except Exception:
            logger.error(f'Failed to load the extension "{name}".')
            traceback.print_exc()
2023-02-23 18:49:02 +01:00
2023-04-07 05:15:45 +02:00
2023-04-07 17:20:57 +02:00
# This iterator returns the extensions in the order specified in the command-line
2023-02-24 14:01:21 +01:00
def iterator():
    """Yield ``(script_module, name)`` for each enabled extension.

    Extensions are ordered by the position stored in ``state[name][1]``,
    i.e. the order in which they were given on the command line.
    """
    def load_position(ext_name):
        return state[ext_name][1]

    for ext_name in sorted(state, key=load_position):
        enabled = state[ext_name][0]
        if not enabled:
            continue
        yield getattr(extensions, ext_name).script, ext_name
2023-02-24 14:01:21 +01:00
2023-04-07 05:15:45 +02:00
2023-04-07 17:20:57 +02:00
# Extension functions that map string -> string
2023-08-13 06:12:15 +02:00
def _apply_string_extensions(function_name, text, state, is_chat=False):
    """Pipe ``text`` through every extension hook named ``function_name``.

    Hooks run in load order, and each one receives the previous hook's output.
    For compatibility with older extensions, the hook's signature is inspected:
    - a hook with exactly two parameters (ignoring ``is_chat``) is called as
      ``hook(text, state)``; any other arity is called as ``hook(text)``;
    - ``is_chat`` is passed as a keyword only if the hook declares it.
    """
    for extension, _ in iterator():
        if not hasattr(extension, function_name):
            continue

        hook = getattr(extension, function_name)
        param_names = list(signature(hook).parameters)
        accepts_is_chat = 'is_chat' in param_names
        positional_count = len(param_names) - (1 if accepts_is_chat else 0)

        call_args = (text, state) if positional_count == 2 else (text,)
        call_kwargs = {'is_chat': is_chat} if accepts_is_chat else {}
        text = hook(*call_args, **call_kwargs)

    return text
2023-04-07 05:15:45 +02:00
2023-07-25 23:49:56 +02:00
# Extension functions that map (text, visible_text) -> (text, visible_text) for chat input
def _apply_chat_input_extensions(text, visible_text, state):
    """Run every ``chat_input_modifier`` hook in load order.

    Each hook receives ``(text, visible_text, state)`` and returns the new
    ``(text, visible_text)`` pair, which is fed to the next hook.
    """
    modifiers = (
        ext.chat_input_modifier
        for ext, _ in iterator()
        if hasattr(ext, 'chat_input_modifier')
    )
    for modify in modifiers:
        text, visible_text = modify(text, visible_text, state)

    return text, visible_text
2023-05-10 01:18:02 +02:00
# custom_generate_chat_prompt handling - currently only the first one will work
2023-04-24 01:32:22 +02:00
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
    """Delegate prompt building to the first extension providing a hook.

    Returns the result of that extension's ``custom_generate_chat_prompt``,
    or ``None`` when no loaded extension defines one.
    """
    provider = next(
        (ext for ext, _ in iterator() if hasattr(ext, 'custom_generate_chat_prompt')),
        None,
    )
    if provider is None:
        return None

    return provider.custom_generate_chat_prompt(text, state, **kwargs)
2023-05-05 23:53:03 +02:00
# Extension that modifies the input parameters before they are used
def _apply_state_modifier_extensions(state):
    """Thread ``state`` through every extension's ``state_modifier`` in load order."""
    for extension, _ in iterator():
        if not hasattr(extension, "state_modifier"):
            continue
        state = extension.state_modifier(state)

    return state
2023-05-10 03:49:39 +02:00
2023-05-05 23:53:03 +02:00
2023-05-21 18:24:54 +02:00
# Extension that modifies the chat history before it is used
def _apply_history_modifier_extensions(history):
    """Thread ``history`` through every extension's ``history_modifier`` in load order."""
    for extension, _ in iterator():
        if not hasattr(extension, "history_modifier"):
            continue
        history = extension.history_modifier(history)

    return history
2023-07-13 22:22:41 +02:00
# Extension functions that override the default tokenizer output - applied in command-line order, each receiving the previous one's output
2023-04-24 01:32:22 +02:00
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
    """Chain every ``function_name`` tokenizer hook over the tokenized prompt.

    Each hook is called as ``hook(state, prompt, input_ids, input_embeds)`` and
    must return the updated ``(prompt, input_ids, input_embeds)`` triple.
    """
    for extension, _ in iterator():
        if not hasattr(extension, function_name):
            continue
        tokenizer_hook = getattr(extension, function_name)
        prompt, input_ids, input_embeds = tokenizer_hook(state, prompt, input_ids, input_embeds)

    return prompt, input_ids, input_embeds
2023-07-13 22:22:41 +02:00
# Allow extensions to add their own logits processors to the stack being run.
# Each extension would call `processor_list.append({their LogitsProcessor}())`.
def _apply_logits_processor_extensions(function_name, processor_list, input_ids):
    """Let each extension add to (or replace) the logits-processor stack.

    Each ``function_name`` hook is called with ``(processor_list, input_ids)``.
    A hook may mutate ``processor_list`` in place, or return a list to use as
    the new stack. Any non-list return value (e.g. ``None``) keeps the current one.

    Returns:
        The final processor list after all hooks ran, in load order.
    """
    for extension, _ in iterator():
        if hasattr(extension, function_name):
            result = getattr(extension, function_name)(processor_list, input_ids)
            # isinstance instead of `type(result) is list`: list subclasses such
            # as transformers' LogitsProcessorList must also replace the stack
            # rather than being silently ignored.
            if isinstance(result, list):
                processor_list = result

    return processor_list
2023-07-13 22:22:41 +02:00
2023-05-10 01:18:02 +02:00
# Get prompt length in tokens after applying extension functions which override the default tokenizer output
# currently only the first one will work
def _apply_custom_tokenized_length(prompt):
    """Return the first extension's ``custom_tokenized_length(prompt)``, or ``None``."""
    provider = next(
        (ext for ext, _ in iterator() if hasattr(ext, 'custom_tokenized_length')),
        None,
    )
    return None if provider is None else provider.custom_tokenized_length(prompt)
# Custom generate reply handling - currently only the first one will work
2023-05-05 23:53:03 +02:00
def _apply_custom_generate_reply():
    """Return (without calling) the first extension's ``custom_generate_reply``, or ``None``."""
    provider = next(
        (ext for ext, _ in iterator() if hasattr(ext, 'custom_generate_reply')),
        None,
    )
    if provider is None:
        return None

    return provider.custom_generate_reply
2023-05-17 05:03:39 +02:00
def _apply_custom_css():
    """Return the combined ``custom_css()`` output of all loaded extensions.

    Chunks are joined with a newline in load order, so one extension's
    stylesheet cannot run into the next one's first token. With zero or one
    contributing extension, the result is the same as plain concatenation.
    """
    return '\n'.join(
        extension.custom_css()
        for extension, _ in iterator()
        if hasattr(extension, 'custom_css')
    )
def _apply_custom_js():
    """Return the combined ``custom_js()`` output of all loaded extensions.

    Snippets are joined with a newline in load order. Plain concatenation would
    fuse two snippets that lack a trailing newline or semicolon
    (``"a()" + "b()"`` -> ``"a()b()"``), producing a syntax error that breaks
    every extension's JS. With zero or one contributing extension, the result
    is unchanged.
    """
    return '\n'.join(
        extension.custom_js()
        for extension, _ in iterator()
        if hasattr(extension, 'custom_js')
    )
2023-05-17 06:25:01 +02:00
def create_extensions_block():
    """Render the UI of every extension that is not shown as its own tab.

    Extensions whose ``params`` set ``is_tab`` are skipped here; the rest are
    placed together in one ``gr.Column`` with ``elem_id="extensions"``. No
    column is created if there is nothing to show.
    """
    inline_extensions = [
        extension
        for extension, _ in iterator()
        if hasattr(extension, "ui")
        and not (hasattr(extension, 'params') and extension.params.get('is_tab', False))
    ]

    # Creating the extension ui elements
    if inline_extensions:
        with gr.Column(elem_id="extensions"):
            for extension in inline_extensions:
                extension.ui()
def create_extensions_tabs():
    """Render each extension that sets ``params['is_tab']`` inside its own ``gr.Tab``.

    The tab title is ``params['display_name']``, falling back to the extension name.
    """
    for extension, name in iterator():
        if not hasattr(extension, "ui"):
            continue
        if not hasattr(extension, 'params') or not extension.params.get('is_tab', False):
            continue

        tab_title = extension.params.get('display_name', name)
        with gr.Tab(tab_title, elem_classes="extension-tab"):
            extension.ui()
2023-04-24 01:32:22 +02:00
# Dispatch table used by apply_extensions(): hook category -> runner.
# The generic runners are bound to the extension attribute they look up
# through functools.partial.
EXTENSION_MAP = dict(
    input=partial(_apply_string_extensions, "input_modifier"),
    output=partial(_apply_string_extensions, "output_modifier"),
    chat_input=_apply_chat_input_extensions,
    state=_apply_state_modifier_extensions,
    history=_apply_history_modifier_extensions,
    bot_prefix=partial(_apply_string_extensions, "bot_prefix_modifier"),
    tokenizer=partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
    logits_processor=partial(_apply_logits_processor_extensions, 'logits_processor_modifier'),
    custom_generate_chat_prompt=_apply_custom_generate_chat_prompt,
    custom_generate_reply=_apply_custom_generate_reply,
    tokenized_length=_apply_custom_tokenized_length,
    css=_apply_custom_css,
    js=_apply_custom_js,
)
def apply_extensions(typ, *args, **kwargs):
    """Run the extension hooks registered under ``typ`` in ``EXTENSION_MAP``.

    All positional and keyword arguments are forwarded to the runner, and its
    result is returned.

    Raises:
        ValueError: if ``typ`` is not a registered hook category.
    """
    runner = EXTENSION_MAP.get(typ)
    if runner is None:
        raise ValueError(f"Invalid extension type {typ}")

    return runner(*args, **kwargs)