diff --git a/extensions/example/script.py b/extensions/example/script.py index 71c7e4ee..fd303dcb 100644 --- a/extensions/example/script.py +++ b/extensions/example/script.py @@ -1,5 +1,6 @@ params = { "input suffix": " *I say as I make a funny face*", + "bot prefix": " *I speak in a cute way*", } def input_modifier(string): @@ -16,3 +17,12 @@ def output_modifier(string): """ return string + +def bot_prefix_modifier(string): + """ + This function is only applied in chat mode. It modifies + the prefix text for the Bot and can be used to bias its + behavior. + """ + + return string + params["bot prefix"] diff --git a/extensions/google_translate/script.py b/extensions/google_translate/script.py index c7f56d6c..78bf7c08 100644 --- a/extensions/google_translate/script.py +++ b/extensions/google_translate/script.py @@ -20,3 +20,12 @@ def output_modifier(string): """ return translator.translate(string, src="en", dest=params['language string']).text + +def bot_prefix_modifier(string): + """ + This function is only applied in chat mode. It modifies + the prefix text for the Bot and can be used to bias its + behavior. + """ + + return string diff --git a/server.py b/server.py index d1fa9fbb..33644beb 100644 --- a/server.py +++ b/server.py @@ -235,10 +235,12 @@ def apply_extensions(text, typ): for ext in sorted(extension_state, key=lambda x : extension_state[x][1]): if extension_state[ext][0] == True: ext_string = f"extensions.{ext}.script" - if typ == "input": + if typ == "input" and hasattr(eval(ext_string), "input_modifier"): text = eval(f"{ext_string}.input_modifier(text)") - else: + elif typ == "output" and hasattr(eval(ext_string), "output_modifier"): text = eval(f"{ext_string}.output_modifier(text)") + elif typ == "bot_prefix" and hasattr(eval(ext_string), "bot_prefix_modifier"): + text = eval(f"{ext_string}.bot_prefix_modifier(text)") return text def update_extensions_parameters(*kwargs): @@ -274,7 +276,6 @@ def create_extensions_block(): btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], []) return extensions_ui_elements, btn_extensions - def get_available_models(): return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower) @@ -353,7 +354,7 @@ if args.chat or args.cai_chat: if history_size != 0 and count >= history_size: break rows.append(f"{name1}: {text}\n") - rows.append(f"{name2}:") + rows.append(apply_extensions(f"{name2}:", "bot_prefix")) while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens: rows.pop(1) @@ -376,7 +377,7 @@ if args.chat or args.cai_chat: idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)] idx = idx[len(previous_idx)-1] - reply = reply[idx + len(f"\n{name2}:"):] + reply = reply[idx + 1 + len(apply_extensions(f"{name2}:", "bot_prefix")):] if check: reply = reply.split('\n')[0].strip() else: