diff --git a/extensions/superbooga/chromadb.py b/extensions/superbooga/chromadb.py index 8675607d..52f4854b 100644 --- a/extensions/superbooga/chromadb.py +++ b/extensions/superbooga/chromadb.py @@ -45,15 +45,32 @@ class ChromaCollector(Collecter): self.ids = [f"id{i}" for i in range(len(texts))] self.collection.add(documents=texts, ids=self.ids) - def get(self, search_strings: list[str], n_results: int) -> list[str]: + def get_documents_and_ids(self, search_strings: list[str], n_results: int): n_results = min(len(self.ids), n_results) - result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0] - return result + result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents']) + documents = result['documents'][0] + ids = list(map(lambda x: int(x[2:]), result['ids'][0])) + return documents, ids + # Get chunks by similarity + def get(self, search_strings: list[str], n_results: int) -> list[str]: + documents, _ = self.get_documents_and_ids(search_strings, n_results) + return documents + + # Get ids by similarity def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: - n_results = min(len(self.ids), n_results) - result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0] - return list(map(lambda x: int(x[2:]), result)) + _ , ids = self.get_documents_and_ids(search_strings, n_results) + return ids + + # Get chunks by similarity and then sort by insertion order + def get_sorted(self, search_strings: list[str], n_results: int) -> list[str]: + documents, ids = self.get_documents_and_ids(search_strings, n_results) + return [x for _, x in sorted(zip(ids, documents))] + + # Get ids by similarity and then sort by insertion order + def get_ids_sorted(self, search_strings: list[str], n_results: int) -> list[str]: + _ , ids = self.get_documents_and_ids(search_strings, n_results) + return sorted(ids) def clear(self): self.collection.delete(ids=self.ids) diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py index cc8454da..a1d66add 100644 --- a/extensions/superbooga/script.py +++ b/extensions/superbooga/script.py @@ -96,7 +96,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): global chat_collector if state['mode'] == 'instruct': - results = collector.get(user_input, n_results=chunk_count) + results = collector.get_sorted(user_input, n_results=chunk_count) additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results) user_input += additional_context else: @@ -116,7 +116,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): add_chunks_to_collector(chunks, chat_collector) query = '\n'.join(shared.history['internal'][-1] + [user_input]) try: - best_ids = chat_collector.get_ids(query, n_results=chunk_count) + best_ids = chat_collector.get_ids_sorted(query, n_results=chunk_count) additional_context = '\n' for id_ in best_ids: if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>': @@ -147,7 +147,7 @@ def input_modifier(string): user_input = match.group(1).strip() # Get the most similar chunks - results = collector.get(user_input, n_results=chunk_count) + results = collector.get_sorted(user_input, n_results=chunk_count) # Make the injection string = string.replace('<|injection-point|>', '\n'.join(results))