memoize load_character to speed up the chat API

This commit is contained in:
oobabooga 2023-05-23 00:50:58 -03:00
parent 8b9ba3d7b4
commit 4d94a111d4
2 changed files with 9 additions and 3 deletions

View File

@ -4,7 +4,7 @@ from threading import Thread
from typing import Callable, Optional from typing import Callable, Optional
from modules import shared from modules import shared
from modules.chat import load_character from modules.chat import load_character_memoized
def build_parameters(body, chat=False): def build_parameters(body, chat=False):
@ -41,8 +41,8 @@ def build_parameters(body, chat=False):
if chat: if chat:
character = body.get('character') character = body.get('character')
instruction_template = body.get('instruction_template') instruction_template = body.get('instruction_template')
name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False) name1, name2, _, greeting, context, _ = load_character_memoized(character, shared.settings['name1'], shared.settings['name2'], instruct=False)
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True) name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
generate_params.update({ generate_params.update({
'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])), 'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])), 'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),

View File

@ -1,6 +1,7 @@
import ast import ast
import base64 import base64
import copy import copy
import functools
import io import io
import json import json
import re import re
@ -568,6 +569,11 @@ def load_character(character, name1, name2, instruct=False):
return name1, name2, picture, greeting, context, repr(turn_template)[1:-1] return name1, name2, picture, greeting, context, repr(turn_template)[1:-1]
# Bounded LRU instead of an unbounded functools.cache: the `character` key is
# taken from API request bodies, so the set of distinct keys is not under this
# module's control and an unbounded cache would grow for the process lifetime.
@functools.lru_cache(maxsize=128)
def load_character_memoized(character, name1, name2, instruct=False):
    """Return load_character(...) for these arguments, reusing cached results.

    Intended for hot paths (such as the chat API) that request the same
    character repeatedly.

    Args:
        character: Character or instruction-template name; must be hashable.
        name1: Passed through to load_character unchanged; must be hashable.
        name2: Passed through to load_character unchanged; must be hashable.
        instruct: Forwarded to load_character as the ``instruct`` keyword.

    Returns:
        Whatever load_character returns for the same arguments. Repeated calls
        with the same arguments return the *same* object, so callers should
        treat it as read-only.

    Notes:
        - On a cache hit load_character is not called at all, so any side
          effects it has are skipped.
        - Results are not refreshed if the underlying character data changes;
          call ``load_character_memoized.cache_clear()`` to force a reload.
        - Calls that raise are not cached and will be retried on the next call.
    """
    return load_character(character, name1, name2, instruct=instruct)
def upload_character(json_file, img, tavern=False): def upload_character(json_file, img, tavern=False):
json_file = json_file if type(json_file) == str else json_file.decode('utf-8') json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
data = json.loads(json_file) data = json.loads(json_file)