Fix issue where n_results/k > index size (#1929)

This commit is contained in:
kaiokendev 2023-05-08 20:16:00 -04:00 committed by GitHub
parent 68dcbc7ebd
commit 0e27b660e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -52,10 +52,12 @@ class ChromaCollector(Collecter):
self.collection.add(documents=texts, ids=self.ids) self.collection.add(documents=texts, ids=self.ids)
def get(self, search_strings: list[str], n_results: int) -> list[str]: def get(self, search_strings: list[str], n_results: int) -> list[str]:
n_results = min(len(self.ids), n_results)
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0] result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
return result return result
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
n_results = min(len(self.ids), n_results)
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0] result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
return list(map(lambda x : int(x[2:]), result)) return list(map(lambda x : int(x[2:]), result))