From 4d94a111d498aa99e162f50edc6cfb129897f220 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 23 May 2023 00:50:58 -0300 Subject: [PATCH] memoize load_character to speed up the chat API --- extensions/api/util.py | 6 +++--- modules/chat.py | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/extensions/api/util.py b/extensions/api/util.py index bd86f8d1..596caee2 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -4,7 +4,7 @@ from threading import Thread from typing import Callable, Optional from modules import shared -from modules.chat import load_character +from modules.chat import load_character_memoized def build_parameters(body, chat=False): @@ -41,8 +41,8 @@ def build_parameters(body, chat=False): if chat: character = body.get('character') instruction_template = body.get('instruction_template') - name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False) - name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True) + name1, name2, _, greeting, context, _ = load_character_memoized(character, shared.settings['name1'], shared.settings['name2'], instruct=False) + name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True) generate_params.update({ 'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])), 'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])), diff --git a/modules/chat.py b/modules/chat.py index a4d10f2d..be5eb9a7 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -1,6 +1,7 @@ import ast import base64 import copy +import functools import io import json import re @@ -568,6 +569,11 @@ def load_character(character, name1, name2, instruct=False): return name1, name2, picture, greeting, context, repr(turn_template)[1:-1] +@functools.cache +def load_character_memoized(character, name1, name2, instruct=False): + return load_character(character, name1, name2, instruct=instruct) + + def upload_character(json_file, img, tavern=False): json_file = json_file if type(json_file) == str else json_file.decode('utf-8') data = json.loads(json_file)