text-generation-webui/extensions/superbooga/chromadb.py

102 lines
3.0 KiB
Python
Raw Normal View History

import chromadb
2023-05-13 19:14:59 +02:00
import posthog
import torch
from chromadb.config import Settings
2023-05-13 19:14:59 +02:00
from sentence_transformers import SentenceTransformer
from modules.logging_colors import logger
2023-05-13 19:14:59 +02:00
logger.info('Intercepting all calls to posthog :)')


def _discard_posthog_capture(*args, **kwargs):
    """Stand-in for ``posthog.capture``: accept any arguments and do nothing."""
    return None


# Monkeypatch posthog so every capture() call anywhere in the process is a no-op.
posthog.capture = _discard_posthog_capture
class Collecter:
    """Base interface for a store of text chunks searchable by similarity.

    Every method here is a placeholder that does nothing and returns ``None``;
    concrete collectors (e.g. ``ChromaCollector``) override them.
    """

    def __init__(self):
        """Set up the collector; the base class holds no state."""

    def add(self, texts: list[str]):
        """Intended to store ``texts`` in the collector."""

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        """Intended to return up to ``n_results`` stored chunks matching ``search_strings``."""

    def clear(self):
        """Intended to remove every stored chunk."""
class Embedder:
    """Base interface for objects that turn text into embeddings.

    The base ``embed`` is a placeholder that returns ``None``; subclasses
    such as ``SentenceTransformerEmbedder`` supply a real implementation.
    """

    def __init__(self):
        """Set up the embedder; the base class holds no state."""

    def embed(self, text: str) -> list[torch.Tensor]:
        """Intended to return the embedding(s) for ``text``."""
class ChromaCollector(Collecter):
    """Collecter backed by a Chroma collection named "context".

    Each stored text gets the Chroma id ``f"id{N}"``, where ``N`` is its
    insertion index. The id-returning methods parse ``N`` back out, so results
    can be reported or re-ordered by insertion order.
    """

    def __init__(self, embedder: Embedder):
        super().__init__()
        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.embedder = embedder
        self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
        # Chroma ids of the documents this collector has added, in insertion order.
        self.ids = []

    def add(self, texts: list[str]):
        """Store ``texts`` after any documents that were already added.

        Numbering continues from the current count, so calling ``add`` twice
        without ``clear`` neither reuses an id already in the collection nor
        drops earlier ids from ``self.ids`` (which would leave those documents
        behind on ``clear``).
        """
        if len(texts) == 0:
            return

        start = len(self.ids)
        new_ids = [f"id{i}" for i in range(start, start + len(texts))]
        self.collection.add(documents=texts, ids=new_ids)
        # Record the ids only once Chroma has accepted the documents.
        self.ids.extend(new_ids)

    def get_documents_and_ids(self, search_strings: list[str], n_results: int):
        """Query the collection and return a ``(documents, ids)`` pair.

        Only the first entry of Chroma's outer result lists is used (one entry
        per query text), i.e. the matches for the first search string.
        ``n_results`` is capped at the number of stored documents; when that
        leaves nothing to fetch, ``([], [])`` is returned. ``ids`` are the
        integer insertion indices of the returned ``documents``.
        """
        n_results = min(len(self.ids), n_results)
        if n_results <= 0:
            return [], []

        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])
        documents = result['documents'][0]
        # Strip the "id" prefix added in add() to recover the insertion index.
        ids = [int(chroma_id[2:]) for chroma_id in result['ids'][0]]
        return documents, ids

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        """Return chunks ranked by similarity, in the order Chroma returns them."""
        documents, _ = self.get_documents_and_ids(search_strings, n_results)
        return documents

    def get_ids(self, search_strings: list[str], n_results: int) -> list[int]:
        """Return insertion indices of chunks ranked by similarity."""
        _, ids = self.get_documents_and_ids(search_strings, n_results)
        return ids

    def get_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
        """Return chunks selected by similarity, re-ordered by insertion order."""
        documents, ids = self.get_documents_and_ids(search_strings, n_results)
        return [x for _, x in sorted(zip(ids, documents))]

    def get_ids_sorted(self, search_strings: list[str], n_results: int) -> list[int]:
        """Return insertion indices of chunks selected by similarity, ascending."""
        _, ids = self.get_documents_and_ids(search_strings, n_results)
        return sorted(ids)

    def clear(self):
        """Delete every document this collector has added."""
        # Skip the call when nothing is stored: depending on the Chroma
        # version, an empty id list is either treated as "no id filter"
        # (delete everything) or rejected outright.
        if self.ids:
            self.collection.delete(ids=self.ids)
        self.ids = []
2023-05-13 19:14:59 +02:00
class SentenceTransformerEmbedder(Embedder):
    """Embedder that delegates to a sentence-transformers model.

    Args:
        model_name: Model name or path passed to ``SentenceTransformer``.
            Defaults to ``"sentence-transformers/all-mpnet-base-v2"``, the
            model this class always used before the parameter existed.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-mpnet-base-v2") -> None:
        super().__init__()
        self.model = SentenceTransformer(model_name)
        # Bind the model's encode method directly as ``embed``. This instance
        # attribute shadows Embedder.embed, and it is what ChromaCollector
        # hands to Chroma as the collection's embedding function.
        self.embed = self.model.encode
def make_collector():
    """Create a new ChromaCollector that uses the module-level ``embedder``.

    ``embedder`` is only read here, so no ``global`` declaration is needed;
    it is looked up at call time and must already be defined by then.
    """
    return ChromaCollector(embedder)
def add_chunks_to_collector(chunks, collector):
    """Reset *collector* so it holds *chunks*.

    Calls ``collector.clear()`` first and then ``collector.add(chunks)``;
    returns ``None``.
    """
    target = collector
    target.clear()
    target.add(chunks)
# Shared embedder that make_collector() hands to every ChromaCollector.
# It is built at import time, so importing this module constructs the
# SentenceTransformer model.
embedder = SentenceTransformerEmbedder()