mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
79 lines
2.1 KiB
Python
79 lines
2.1 KiB
Python
|
import logging
|
||
|
|
||
|
import posthog
|
||
|
import torch
|
||
|
from sentence_transformers import SentenceTransformer
|
||
|
|
||
|
import chromadb
|
||
|
from chromadb.config import Settings
|
||
|
|
||
|
logging.info('Intercepting all calls to posthog :)')
|
||
|
posthog.capture = lambda *args, **kwargs: None
|
||
|
|
||
|
|
||
|
class Collecter():
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
|
||
|
def add(self, texts: list[str]):
|
||
|
pass
|
||
|
|
||
|
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||
|
pass
|
||
|
|
||
|
def clear(self):
|
||
|
pass
|
||
|
|
||
|
|
||
|
class Embedder():
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
|
||
|
def embed(self, text: str) -> list[torch.Tensor]:
|
||
|
pass
|
||
|
|
||
|
|
||
|
class ChromaCollector(Collecter):
|
||
|
def __init__(self, embedder: Embedder):
|
||
|
super().__init__()
|
||
|
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
||
|
self.embedder = embedder
|
||
|
self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
|
||
|
self.ids = []
|
||
|
|
||
|
def add(self, texts: list[str]):
|
||
|
self.ids = [f"id{i}" for i in range(len(texts))]
|
||
|
self.collection.add(documents=texts, ids=self.ids)
|
||
|
|
||
|
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||
|
n_results = min(len(self.ids), n_results)
|
||
|
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
|
||
|
return result
|
||
|
|
||
|
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||
|
n_results = min(len(self.ids), n_results)
|
||
|
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
|
||
|
return list(map(lambda x: int(x[2:]), result))
|
||
|
|
||
|
def clear(self):
|
||
|
self.collection.delete(ids=self.ids)
|
||
|
|
||
|
|
||
|
class SentenceTransformerEmbedder(Embedder):
|
||
|
def __init__(self) -> None:
|
||
|
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
|
||
|
self.embed = self.model.encode
|
||
|
|
||
|
|
||
|
def make_collector():
|
||
|
global embedder
|
||
|
return ChromaCollector(embedder)
|
||
|
|
||
|
|
||
|
def add_chunks_to_collector(chunks, collector):
|
||
|
collector.clear()
|
||
|
collector.add(chunks)
|
||
|
|
||
|
|
||
|
embedder = SentenceTransformerEmbedder()
|