Generalize superbooga to chat mode

This commit is contained in:
oobabooga 2023-05-07 15:01:14 -03:00
parent ec1cda0e1f
commit 6b67cb6611
2 changed files with 65 additions and 17 deletions

View File

@ -8,9 +8,10 @@ import posthog
import torch import torch
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from chromadb.config import Settings from chromadb.config import Settings
from modules import shared
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from modules import chat, shared
print('Intercepting all calls to posthog :)') print('Intercepting all calls to posthog :)')
posthog.capture = lambda *args, **kwargs: None posthog.capture = lambda *args, **kwargs: None
@ -53,6 +54,10 @@ class ChromaCollector(Collecter):
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0] result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
return result return result
def get_ids(self, search_strings: list[str], n_results: int) -> list[int]:
    """Return the ids of the ``n_results`` stored chunks most similar to
    ``search_strings``.

    Unlike ``get``, which returns the chunk texts, this returns integer
    positions so the caller can look the chunks up in its own list.
    """
    result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
    # Stored ids carry a 2-character prefix (presumably 'id<N>' — the slice
    # below assumes exactly that shape); strip it to recover the integer.
    return [int(id_[2:]) for id_ in result]
def clear(self): def clear(self):
self.collection.delete(ids=self.ids) self.collection.delete(ids=self.ids)
@ -68,18 +73,24 @@ collector = ChromaCollector(embedder)
chunk_count = 5 chunk_count = 5
def feed_data_into_collector(corpus, chunk_len): def add_chunks_to_collector(chunks):
global collector global collector
chunk_len = int(chunk_len) collector.clear()
collector.add(chunks)
def feed_data_into_collector(corpus, chunk_len):
# Defining variables
chunk_len = int(chunk_len)
cumulative = '' cumulative = ''
# Breaking the data into chunks and adding those to the db
cumulative += "Breaking the input dataset...\n\n" cumulative += "Breaking the input dataset...\n\n"
yield cumulative yield cumulative
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)] data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n" cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
yield cumulative yield cumulative
collector.clear() add_chunks_to_collector(data_chunks)
collector.add(data_chunks)
cumulative += "Done." cumulative += "Done."
yield cumulative yield cumulative
@ -123,6 +134,8 @@ def apply_settings(_chunk_count):
def input_modifier(string): def input_modifier(string):
if shared.is_chat():
return string
# Find the user input # Find the user input
pattern = re.compile(r"<\|begin-user-input\|>(.*?)<\|end-user-input\|>", re.DOTALL) pattern = re.compile(r"<\|begin-user-input\|>(.*?)<\|end-user-input\|>", re.DOTALL)
@ -143,6 +156,23 @@ def input_modifier(string):
return string return string
def custom_generate_chat_prompt(user_input, state, **kwargs):
    """Build the chat prompt with past exchanges reordered by relevance.

    When the history holds more than two exchanges and the user typed
    something, every exchange except the latest is indexed in the Chroma
    collector and then retrieved sorted by similarity to the current
    query (latest exchange + new input). The reordered history is handed
    to the regular prompt builder through ``state['history']``; otherwise
    the prompt is built from the unmodified history.
    """
    history = shared.history['internal']
    if len(history) > 2 and user_input != '':
        # Index every exchange except the most recent one.
        add_chunks_to_collector(['\n'.join(pair) for pair in history[:-1]])

        # Query with the latest exchange plus the new user input.
        query = '\n'.join(history[-1] + [user_input])
        best_ids = collector.get_ids(query, n_results=len(history) - 1)

        # Least relevant first, most relevant last, then the latest
        # exchange — so the freshest context sits closest to the end of
        # the prompt instead of keeping chronological order.
        reordered = [history[i] for i in reversed(best_ids)]
        reordered.append(history[-1])
        state['history'] = reordered

    return chat.generate_chat_prompt(user_input, state, **kwargs)
def ui(): def ui():
with gr.Accordion("Click for more information...", open=False): with gr.Accordion("Click for more information...", open=False):
gr.Markdown(textwrap.dedent(""" gr.Markdown(textwrap.dedent("""
@ -156,7 +186,9 @@ def ui():
It is a modified version of the superbig extension by kaiokendev: https://github.com/kaiokendev/superbig It is a modified version of the superbig extension by kaiokendev: https://github.com/kaiokendev/superbig
## How to use it ## Notebook/default modes
### How to use it
1) Paste your input text (of whatever length) into the text box below. 1) Paste your input text (of whatever length) into the text box below.
2) Click on "Load data" to feed this text into the Chroma database. 2) Click on "Load data" to feed this text into the Chroma database.
@ -166,7 +198,7 @@ def ui():
The special tokens mentioned above (`<|begin-user-input|>`, `<|end-user-input|>`, and `<|injection-point|>`) are removed when the injection happens. The special tokens mentioned above (`<|begin-user-input|>`, `<|end-user-input|>`, and `<|injection-point|>`) are removed when the injection happens.
## Example ### Example
For your convenience, you can use the following prompt as a starting point (for Alpaca models): For your convenience, you can use the following prompt as a starting point (for Alpaca models):
@ -186,14 +218,25 @@ def ui():
### Response: ### Response:
``` ```
## Chat mode
In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
That is, the prompt will include (starting from the end):
* Your input
* The latest input/reply pair
* The #1 most relevant input/reply pair prior to the latest
* The #2 most relevant input/reply pair prior to the latest
* Etc.
This way, the bot can have a long-term history.
*This extension is currently experimental and under development.* *This extension is currently experimental and under development.*
""")) """))
if shared.is_chat(): if not shared.is_chat():
# Chat mode has to be handled differently, probably using a custom_generate_chat_prompt
pass
else:
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Tab("Text input"): with gr.Tab("Text input"):

View File

@ -27,6 +27,11 @@ def replace_all(text, dic):
def generate_chat_prompt(user_input, state, **kwargs): def generate_chat_prompt(user_input, state, **kwargs):
# Check if an extension is sending its modified history.
# If not, use the regular history
history = state['history'] if 'history' in state else shared.history['internal']
# Define some variables
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False _continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
@ -61,14 +66,14 @@ def generate_chat_prompt(user_input, state, **kwargs):
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements) bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
# Building the prompt # Building the prompt
i = len(shared.history['internal']) - 1 i = len(history) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length: while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
if _continue and i == len(shared.history['internal']) - 1: if _continue and i == len(history) - 1:
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip()) rows.insert(1, bot_turn_stripped + history[i][1].strip())
else: else:
rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip())) rows.insert(1, bot_turn.replace('<|bot-message|>', history[i][1].strip()))
string = shared.history['internal'][i][0] string = history[i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)})) rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
@ -80,7 +85,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
elif not _continue: elif not _continue:
# Adding the user message # Adding the user message
if len(user_input) > 0: if len(user_input) > 0:
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))})) rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
# Adding the Character prefix # Adding the Character prefix
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' '))) rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))