mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
parent
ab6acddcc5
commit
9c53517d2c
@ -42,11 +42,17 @@ class ChromaCollector(Collecter):
|
||||
self.ids = []
|
||||
|
||||
def add(self, texts: list[str]):
|
||||
if len(texts) == 0:
|
||||
return
|
||||
|
||||
self.ids = [f"id{i}" for i in range(len(texts))]
|
||||
self.collection.add(documents=texts, ids=self.ids)
|
||||
|
||||
def get_documents_and_ids(self, search_strings: list[str], n_results: int):
|
||||
n_results = min(len(self.ids), n_results)
|
||||
if n_results == 0:
|
||||
return [], []
|
||||
|
||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])
|
||||
documents = result['documents'][0]
|
||||
ids = list(map(lambda x: int(x[2:]), result['ids'][0]))
|
||||
@ -59,7 +65,7 @@ class ChromaCollector(Collecter):
|
||||
|
||||
# Get ids by similarity
|
||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
_ , ids = self.get_documents_and_ids(search_strings, n_results)
|
||||
_, ids = self.get_documents_and_ids(search_strings, n_results)
|
||||
return ids
|
||||
|
||||
# Get chunks by similarity and then sort by insertion order
|
||||
@ -69,11 +75,12 @@ class ChromaCollector(Collecter):
|
||||
|
||||
# Get ids by similarity and then sort by insertion order
|
||||
def get_ids_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
_ , ids = self.get_documents_and_ids(search_strings, n_results)
|
||||
_, ids = self.get_documents_and_ids(search_strings, n_results)
|
||||
return sorted(ids)
|
||||
|
||||
def clear(self):
|
||||
self.collection.delete(ids=self.ids)
|
||||
self.ids = []
|
||||
|
||||
|
||||
class SentenceTransformerEmbedder(Embedder):
|
||||
|
Loading…
Reference in New Issue
Block a user