From 6b67cb6611e8b80e1202336637242e49c2575c51 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 7 May 2023 15:01:14 -0300 Subject: [PATCH] Generalize superbooga to chat mode --- extensions/superbooga/script.py | 65 +++++++++++++++++++++++++++------ modules/chat.py | 17 ++++++--- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py index 0cc16d50..649cc0d9 100644 --- a/extensions/superbooga/script.py +++ b/extensions/superbooga/script.py @@ -8,9 +8,10 @@ import posthog import torch from bs4 import BeautifulSoup from chromadb.config import Settings -from modules import shared from sentence_transformers import SentenceTransformer +from modules import chat, shared + print('Intercepting all calls to posthog :)') posthog.capture = lambda *args, **kwargs: None @@ -53,6 +54,10 @@ class ChromaCollector(Collecter): result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0] return result + def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: + result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0] + return list(map(lambda x : int(x[2:]), result)) + def clear(self): self.collection.delete(ids=self.ids) @@ -68,18 +73,24 @@ collector = ChromaCollector(embedder) chunk_count = 5 -def feed_data_into_collector(corpus, chunk_len): +def add_chunks_to_collector(chunks): global collector - chunk_len = int(chunk_len) + collector.clear() + collector.add(chunks) + +def feed_data_into_collector(corpus, chunk_len): + # Defining variables + chunk_len = int(chunk_len) cumulative = '' + + # Breaking the data into chunks and adding those to the db cumulative += "Breaking the input dataset...\n\n" yield cumulative data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)] cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n" yield cumulative - collector.clear() - collector.add(data_chunks) + add_chunks_to_collector(data_chunks) cumulative += "Done." yield cumulative @@ -123,6 +134,8 @@ def apply_settings(_chunk_count): def input_modifier(string): + if shared.is_chat(): + return string # Find the user input pattern = re.compile(r"<\|begin-user-input\|>(.*?)<\|end-user-input\|>", re.DOTALL) @@ -143,6 +156,23 @@ def input_modifier(string): return string +def custom_generate_chat_prompt(user_input, state, **kwargs): + if len(shared.history['internal']) > 2 and user_input != '': + chunks = [] + for i in range(len(shared.history['internal'])-1): + chunks.append('\n'.join(shared.history['internal'][i])) + + add_chunks_to_collector(chunks) + query = '\n'.join(shared.history['internal'][-1] + [user_input]) + best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1) + + # Sort the history by relevance instead of by chronological order, + # except for the latest message + state['history'] = [shared.history['internal'][id_] for id_ in best_ids[::-1]] + [shared.history['internal'][-1]] + + return chat.generate_chat_prompt(user_input, state, **kwargs) + + def ui(): with gr.Accordion("Click for more information...", open=False): gr.Markdown(textwrap.dedent(""" @@ -156,7 +186,9 @@ def ui(): It is a modified version of the superbig extension by kaiokendev: https://github.com/kaiokendev/superbig - ## How to use it + ## Notebook/default modes + + ### How to use it 1) Paste your input text (of whatever length) into the text box below. 2) Click on "Load data" to feed this text into the Chroma database. @@ -166,7 +198,7 @@ def ui(): The special tokens mentioned above (`<|begin-user-input|>`, `<|end-user-input|>`, and `<|injection-point|>`) are removed when the injection happens. - ## Example + ### Example For your convenience, you can use the following prompt as a starting point (for Alpaca models): @@ -186,14 +218,25 @@ def ui(): ### Response: ``` + ## Chat mode + + In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair. + + That is, the prompt will include (starting from the end): + + * Your input + * The latest input/reply pair + * The #1 most relevant input/reply pair prior to the latest + * The #2 most relevant input/reply pair prior to the latest + * Etc + + This way, the bot can have a long term history. + *This extension is currently experimental and under development.* """)) - if shared.is_chat(): - # Chat mode has to be handled differently, probably using a custom_generate_chat_prompt - pass - else: + if not shared.is_chat(): with gr.Row(): with gr.Column(): with gr.Tab("Text input"): diff --git a/modules/chat.py b/modules/chat.py index 98e171b0..2481128a 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -27,6 +27,11 @@ def replace_all(text, dic): def generate_chat_prompt(user_input, state, **kwargs): + # Check if an extension is sending its modified history. + # If not, use the regular history + history = state['history'] if 'history' in state else shared.history['internal'] + + # Define some variables impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False _continue = kwargs['_continue'] if '_continue' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False @@ -61,14 +66,14 @@ def generate_chat_prompt(user_input, state, **kwargs): bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements) # Building the prompt - i = len(shared.history['internal']) - 1 + i = len(history) - 1 while i >= 0 and len(encode(''.join(rows))[0]) < max_length: - if _continue and i == len(shared.history['internal']) - 1: - rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip()) + if _continue and i == len(history) - 1: + rows.insert(1, bot_turn_stripped + history[i][1].strip()) else: - rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip())) + rows.insert(1, bot_turn.replace('<|bot-message|>', history[i][1].strip())) - string = shared.history['internal'][i][0] + string = history[i][0] if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)})) @@ -80,7 +85,7 @@ def generate_chat_prompt(user_input, state, **kwargs): elif not _continue: # Adding the user message if len(user_input) > 0: - rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))})) + rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))})) # Adding the Character prefix rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))