diff --git a/modules/extensions.py b/modules/extensions.py index 92d86772..547fe3f4 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -23,12 +23,14 @@ def load_extensions(): if extension not in setup_called and hasattr(extension, "setup"): setup_called.add(extension) extension.setup() + state[name] = [True, i] if name != 'api': print('Ok.') except: if name != 'api': print('Fail.') + traceback.print_exc() @@ -44,6 +46,7 @@ def _apply_string_extensions(function_name, text): for extension, _ in iterator(): if hasattr(extension, function_name): text = getattr(extension, function_name)(text) + return text @@ -56,6 +59,7 @@ def _apply_input_hijack(text, visible_text): text, visible_text = extension.input_hijack['value'](text, visible_text) else: text, visible_text = extension.input_hijack['value'] + return text, visible_text @@ -65,8 +69,10 @@ def _apply_custom_generate_chat_prompt(text, state, **kwargs): for extension, _ in iterator(): if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'): custom_generate_chat_prompt = extension.custom_generate_chat_prompt + if custom_generate_chat_prompt is not None: return custom_generate_chat_prompt(text, state, **kwargs) + return None @@ -75,6 +81,7 @@ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_e for extension, _ in iterator(): if hasattr(extension, function_name): prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds) + return prompt, input_ids, input_embeds @@ -91,6 +98,7 @@ EXTENSION_MAP = { def apply_extensions(typ, *args, **kwargs): if typ not in EXTENSION_MAP: raise ValueError(f"Invalid extension type {typ}") + return EXTENSION_MAP[typ](*args, **kwargs)