mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Fix issue where n_results/k > index size (#1929)
This commit is contained in:
parent
68dcbc7ebd
commit
0e27b660e8
@ -52,10 +52,12 @@ class ChromaCollector(Collecter):
|
|||||||
self.collection.add(documents=texts, ids=self.ids)
|
self.collection.add(documents=texts, ids=self.ids)
|
||||||
|
|
||||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
|
n_results = min(len(self.ids), n_results)
|
||||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
|
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
|
n_results = min(len(self.ids), n_results)
|
||||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
|
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
|
||||||
return list(map(lambda x : int(x[2:]), result))
|
return list(map(lambda x : int(x[2:]), result))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user