From 7cc17e3f1f14b110a38869afc31b998250d84e5e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 13 May 2023 14:14:59 -0300 Subject: [PATCH] Refactor superbooga --- extensions/superbooga/chromadb.py | 78 +++++++++++++++++++++++++ extensions/superbooga/script.py | 97 ++++--------------------------- 2 files changed, 90 insertions(+), 85 deletions(-) create mode 100644 extensions/superbooga/chromadb.py diff --git a/extensions/superbooga/chromadb.py b/extensions/superbooga/chromadb.py new file mode 100644 index 00000000..8675607d --- /dev/null +++ b/extensions/superbooga/chromadb.py @@ -0,0 +1,78 @@ +import logging + +import posthog +import torch +from sentence_transformers import SentenceTransformer + +import chromadb +from chromadb.config import Settings + +logging.info('Intercepting all calls to posthog :)') +posthog.capture = lambda *args, **kwargs: None + + +class Collecter(): + def __init__(self): + pass + + def add(self, texts: list[str]): + pass + + def get(self, search_strings: list[str], n_results: int) -> list[str]: + pass + + def clear(self): + pass + + +class Embedder(): + def __init__(self): + pass + + def embed(self, text: str) -> list[torch.Tensor]: + pass + + +class ChromaCollector(Collecter): + def __init__(self, embedder: Embedder): + super().__init__() + self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False)) + self.embedder = embedder + self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed) + self.ids = [] + + def add(self, texts: list[str]): + self.ids = [f"id{i}" for i in range(len(texts))] + self.collection.add(documents=texts, ids=self.ids) + + def get(self, search_strings: list[str], n_results: int) -> list[str]: + n_results = min(len(self.ids), n_results) + result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0] + return result + + def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: + n_results = min(len(self.ids), n_results) + result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0] + return list(map(lambda x: int(x[2:]), result)) + + def clear(self): + self.collection.delete(ids=self.ids) + + +class SentenceTransformerEmbedder(Embedder): + def __init__(self) -> None: + self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2") + self.embed = self.model.encode + + +def make_collector(): + global embedder + return ChromaCollector(embedder) + + +def add_chunks_to_collector(chunks, collector): + collector.clear() + collector.add(chunks) + + +embedder = SentenceTransformerEmbedder() diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py index d409646a..9a83289e 100644 --- a/extensions/superbooga/script.py +++ b/extensions/superbooga/script.py @@ -2,22 +2,14 @@ import logging import re import textwrap -import chromadb import gradio as gr -import posthog -import torch from bs4 import BeautifulSoup -from chromadb.config import Settings -from sentence_transformers import SentenceTransformer - from modules import chat, shared +from .chromadb import add_chunks_to_collector, make_collector from .download_urls import download_urls -logging.info('Intercepting all calls to posthog :)') -posthog.capture = lambda *args, **kwargs: None -# These parameters are customizable through settings.json params = { 'chunk_count': 5, 'chunk_length': 700, @@ -25,72 +17,11 @@ params = { 'threads': 4, } - -class Collecter(): - def __init__(self): - pass - - def add(self, texts: list[str]): - pass - - def get(self, search_strings: list[str], n_results: int) -> list[str]: - pass - - def clear(self): - pass - - -class Embedder(): - def __init__(self): - pass - - def embed(self, text: str) -> list[torch.Tensor]: - pass - - -class ChromaCollector(Collecter): - def __init__(self, embedder: Embedder): - super().__init__() - self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False)) - self.embedder = embedder - self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed) - self.ids = [] - - def add(self, texts: list[str]): - self.ids = [f"id{i}" for i in range(len(texts))] - self.collection.add(documents=texts, ids=self.ids) - - def get(self, search_strings: list[str], n_results: int) -> list[str]: - n_results = min(len(self.ids), n_results) - result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0] - return result - - def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: - n_results = min(len(self.ids), n_results) - result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0] - return list(map(lambda x: int(x[2:]), result)) - - def clear(self): - self.collection.delete(ids=self.ids) - - -class SentenceTransformerEmbedder(Embedder): - def __init__(self) -> None: - self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2") - self.embed = self.model.encode - - -embedder = SentenceTransformerEmbedder() -collector = ChromaCollector(embedder) -chat_collector = ChromaCollector(embedder) +collector = make_collector() +chat_collector = make_collector() chunk_count = 5 -def add_chunks_to_collector(chunks, collector): - collector.clear() - collector.add(chunks) - - def feed_data_into_collector(corpus, chunk_len): global collector @@ -150,6 +81,7 @@ def apply_settings(_chunk_count): settings_to_display = { 'chunk_count': chunk_count, } + yield f"The following settings are now active: {str(settings_to_display)}" @@ -193,10 +125,8 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): def remove_special_tokens(string): - for k in ['<|begin-user-input|>', '<|end-user-input|>', '<|injection-point|>']: - string = string.replace(k, '') - - return string.strip() + pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)' + return re.sub(pattern, '', string) def input_modifier(string): @@ -208,17 +138,14 @@ def input_modifier(string): match = re.search(pattern, string) if match: user_input = match.group(1).strip() - else: - return remove_special_tokens(string) - # Get the most similar chunks - results = collector.get(user_input, n_results=chunk_count) + # Get the most similar chunks + results = collector.get(user_input, n_results=chunk_count) - # Make the replacements - string = string.replace('<|begin-user-input|>', '').replace('<|end-user-input|>', '') - string = string.replace('<|injection-point|>', '\n'.join(results)) + # Make the injection + string = string.replace('<|injection-point|>', '\n'.join(results)) - return string + return remove_special_tokens(string) def ui(): @@ -250,7 +177,7 @@ def ui(): ... ``` - The injection doesn't make it into the chat history. It is only used in the current generation. + The injection doesn't make it into the chat history. It is only used in the current generation. #### Regular chat