diff --git a/extensions/superbooga/chromadb.py b/extensions/superbooga/chromadb.py
index 52f4854b..088a6d7a 100644
--- a/extensions/superbooga/chromadb.py
+++ b/extensions/superbooga/chromadb.py
@@ -42,11 +42,17 @@ class ChromaCollector(Collecter):
         self.ids = []
 
     def add(self, texts: list[str]):
+        if len(texts) == 0:
+            return
+
         self.ids = [f"id{i}" for i in range(len(texts))]
         self.collection.add(documents=texts, ids=self.ids)
 
     def get_documents_and_ids(self, search_strings: list[str], n_results: int):
         n_results = min(len(self.ids), n_results)
+        if n_results == 0:
+            return [], []
+
         result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])
         documents = result['documents'][0]
         ids = list(map(lambda x: int(x[2:]), result['ids'][0]))
@@ -59,7 +65,7 @@ class ChromaCollector(Collecter):
 
     # Get ids by similarity
     def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
-        _ , ids = self.get_documents_and_ids(search_strings, n_results)
+        _, ids = self.get_documents_and_ids(search_strings, n_results)
         return ids
 
     # Get chunks by similarity and then sort by insertion order
@@ -69,11 +75,12 @@ class ChromaCollector(Collecter):
 
     # Get ids by similarity and then sort by insertion order
     def get_ids_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
-        _ , ids = self.get_documents_and_ids(search_strings, n_results)
+        _, ids = self.get_documents_and_ids(search_strings, n_results)
         return sorted(ids)
 
     def clear(self):
         self.collection.delete(ids=self.ids)
+        self.ids = []
 
 
 class SentenceTransformerEmbedder(Embedder):