mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
memoize load_character to speed up the chat API
This commit is contained in:
parent
8b9ba3d7b4
commit
4d94a111d4
@ -4,7 +4,7 @@ from threading import Thread
|
|||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.chat import load_character
|
from modules.chat import load_character_memoized
|
||||||
|
|
||||||
|
|
||||||
def build_parameters(body, chat=False):
|
def build_parameters(body, chat=False):
|
||||||
@ -41,8 +41,8 @@ def build_parameters(body, chat=False):
|
|||||||
if chat:
|
if chat:
|
||||||
character = body.get('character')
|
character = body.get('character')
|
||||||
instruction_template = body.get('instruction_template')
|
instruction_template = body.get('instruction_template')
|
||||||
name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False)
|
name1, name2, _, greeting, context, _ = load_character_memoized(character, shared.settings['name1'], shared.settings['name2'], instruct=False)
|
||||||
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True)
|
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
|
||||||
generate_params.update({
|
generate_params.update({
|
||||||
'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
|
'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
|
||||||
'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),
|
'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import ast
|
import ast
|
||||||
import base64
|
import base64
|
||||||
import copy
|
import copy
|
||||||
|
import functools
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
@ -568,6 +569,11 @@ def load_character(character, name1, name2, instruct=False):
|
|||||||
return name1, name2, picture, greeting, context, repr(turn_template)[1:-1]
|
return name1, name2, picture, greeting, context, repr(turn_template)[1:-1]
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def load_character_memoized(character, name1, name2, instruct=False):
|
||||||
|
return load_character(character, name1, name2, instruct=instruct)
|
||||||
|
|
||||||
|
|
||||||
def upload_character(json_file, img, tavern=False):
|
def upload_character(json_file, img, tavern=False):
|
||||||
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
||||||
data = json.loads(json_file)
|
data = json.loads(json_file)
|
||||||
|
Loading…
Reference in New Issue
Block a user