2023-09-26 21:30:19 -03:00

139 lines
4.7 KiB
Python

"""
This module is responsible for modifying the chat prompt and history.
"""
import json
import re
import extensions.superboogav2.parameters as parameters
from modules import chat
from modules.text_generation import get_encoded_length
from modules.logging_colors import logger
from extensions.superboogav2.utils import create_context_text, create_metadata_source
from .data_processor import process_and_add_to_collector
from .chromadb import ChromaCollector
# Passed along when the chat history is added to the collector, and used as the
# `where` filter when previously auto-inserted documents are deleted.
CHAT_METADATA = create_metadata_source('automatic-chat-insert')
# Values of state['mode'] that are treated as instruct mode in this module.
INSTRUCT_MODE = 'instruct'
CHAT_INSTRUCT_MODE = 'chat-instruct'
def _is_instruct_mode(state: dict):
    """Tell whether the chat state selects one of the instruct-style modes.

    Returns True when ``state['mode']`` equals ``INSTRUCT_MODE`` or
    ``CHAT_INSTRUCT_MODE``. A missing ``'mode'`` key yields False.
    """
    return state.get('mode') in (INSTRUCT_MODE, CHAT_INSTRUCT_MODE)
def _remove_tag_if_necessary(user_input: str):
    """Strip the manual ``!c`` trigger tag from the user's message.

    Outside manual mode the input is returned untouched. In manual mode every
    ``!c`` that sits at the start or end of a line is removed together with
    the whitespace around it.

    Args:
        user_input: The raw text typed by the user.

    Returns:
        The text with the trigger tag(s) removed, or the original text when
        manual mode is off.
    """
    if not parameters.get_is_manual():
        return user_input

    # _should_query() looks for the tag with re.MULTILINE, i.e. at the edges of
    # any line. Use the same flag here so that every tag able to trigger a
    # query is also stripped from the text that gets queried and sent on.
    return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input, flags=re.MULTILINE)
def _should_query(input: str):
    """Decide whether the collector should be queried for this message.

    Outside manual mode the answer is always True. In manual mode it is True
    only if ``!c`` appears at the start or at the end of some line of
    ``input`` (optionally surrounded by whitespace).
    """
    if parameters.get_is_manual():
        return re.search(r'^\s*!c|!c\s*$', input, re.MULTILINE) is not None
    return True
def _format_single_exchange(name, text):
if re.search(r':\s*$', name):
return '{} {}\n'.format(name, text)
else:
return '{}: {}\n'.format(name, text)
def _get_names(state: dict):
    """Return the ``(user_name, bot_name)`` pair for the current mode.

    In instruct mode the names come from ``state['name1_instruct']`` and
    ``state['name2_instruct']``; otherwise from ``state['name1']`` and
    ``state['name2']``. An empty or otherwise falsy name falls back to
    ``'User'`` / ``'Assistant'`` respectively.
    """
    if _is_instruct_mode(state):
        user_key, bot_key = 'name1_instruct', 'name2_instruct'
    else:
        user_key, bot_key = 'name1', 'name2'

    return state[user_key] or 'User', state[bot_key] or 'Assistant'
def _concatinate_history(history: dict, state: dict):
    """Flatten ``history['internal']`` into a single ``name: text`` transcript.

    Each exchange contributes its first item as a user line and its second
    item as a bot line; any further items are ignored. Lines are separated
    by newlines and the result carries no trailing newline. An empty history
    produces an empty string.
    """
    speakers = _get_names(state)  # (user_name, bot_name)

    exchanges = history['internal']
    assert isinstance(exchanges, list)

    lines = []
    for exchange in exchanges:
        assert isinstance(exchange, list)
        # zip() stops after two items, pairing exchange[0] with the user name
        # and exchange[1] with the bot name.
        for speaker, message in zip(speakers, exchange):
            lines.append(_format_single_exchange(speaker, message))

    # Every formatted line ends with '\n'; drop the final one.
    return ''.join(lines)[:-1]
def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
    """Overwrite one message in ``history['internal']`` with ``context_text``.

    Messages are walked from newest to oldest while a running total of their
    token counts is kept. The message chosen for replacement is the oldest
    one for which the tokens of all newer messages plus the tokens of
    ``context_text`` still fit within ``max_len``. The history is modified
    in place. If no message qualifies, a warning is logged and the history
    is left unchanged.

    Args:
        context_text: Text that takes the place of the chosen message.
        history: Chat history; ``history['internal']`` holds the exchanges.
            Each exchange is assumed to hold at most a user and a bot
            message, since speaker names are looked up per message index.
        max_len: Token budget used to decide how far back to go.
        state: Chat state, used to resolve the speaker names.
    """
    num_context_tokens = get_encoded_length(context_text)

    # Exchanges are read back to front, so the bot's name has to come first.
    names = _get_names(state)[::-1]

    # Flat newest-to-oldest view of every message with its reversed indices.
    newest_first = (
        (i, j, message)
        for i, messages in enumerate(reversed(history['internal']))
        for j, message in enumerate(reversed(messages))
    )

    history_tokens = 0
    replace_position = None
    # TODO: This is an extremely naive solution. A more robust implementation must be made.
    for i, j, message in newest_first:
        if history_tokens + num_context_tokens > max_len:
            # The running total only grows, so no older message can fit
            # either; stop instead of tokenizing the rest of the history.
            break
        # This message can be replaced.
        replace_position = (i, j)
        history_tokens += get_encoded_length(_format_single_exchange(names[j], message))

    if replace_position is None:
        # Logger.warn is a deprecated alias; warning() is the supported name.
        logger.warning("The provided context_text is too long to replace any message in the history.")
        return

    # Indices were counted from the end, hence the negative indexing.
    i, j = replace_position
    history['internal'][-i-1][-j-1] = context_text
def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):
    """Build the chat prompt, optionally injecting context from the collector.

    Steps:
      1. If ``parameters.get_add_chat_to_data()`` is set, the chat history in
         ``kwargs['history']`` is flattened to text. When that text is not
         empty, the documents matching ``CHAT_METADATA`` are deleted from
         the collector and the history is added again.
      2. If ``_should_query(user_input)`` is true, the trigger tag is
         stripped, the collector is queried, and the resulting context is
         injected according to the configured strategy: appended to the
         user input, prepended to it, or written over a message in
         ``kwargs['history']``.
      3. The prompt is produced by ``chat.generate_chat_prompt``.

    Args:
        user_input: The user's latest message.
        state: Chat state; ``state['truncation_length']`` is read for the
            hijack strategy.
        collector: Collector that is queried and, optionally, updated.
        **kwargs: Forwarded to ``chat.generate_chat_prompt``;
            ``kwargs['history']`` is required when chat-to-data is enabled
            or the hijack strategy is used.

    Returns:
        Whatever ``chat.generate_chat_prompt`` returns.
    """
    if parameters.get_add_chat_to_data():
        # Get the whole history as one string.
        history_as_text = _concatinate_history(kwargs['history'], state)

        if history_as_text:
            # Delete all documents that were auto-inserted ...
            collector.delete(ids_to_delete=None, where=CHAT_METADATA)
            # ... and insert the processed history again.
            process_and_add_to_collector(history_as_text, collector, False, CHAT_METADATA)

    if not _should_query(user_input):
        return chat.generate_chat_prompt(user_input, state, **kwargs)

    user_input = _remove_tag_if_necessary(user_input)
    results = collector.get_sorted_by_dist(
        user_input,
        n_results=parameters.get_chunk_count(),
        max_token_count=int(parameters.get_max_token_count()),
    )

    # Read the strategy once so that every branch below is compared against
    # the same value, even if the setting is changed while this call runs.
    strategy = parameters.get_injection_strategy()
    if strategy == parameters.APPEND_TO_LAST:
        user_input = user_input + create_context_text(results)
    elif strategy == parameters.PREPEND_TO_LAST:
        user_input = create_context_text(results) + user_input
    elif strategy == parameters.HIJACK_LAST_IN_CONTEXT:
        _hijack_last(create_context_text(results), kwargs['history'], state['truncation_length'], state)

    return chat.generate_chat_prompt(user_input, state, **kwargs)