mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Sort selected superbooga chunks by insertion order
For better coherence
This commit is contained in:
parent
b07f849e41
commit
897fa60069
@ -45,15 +45,32 @@ class ChromaCollector(Collecter):
|
|||||||
self.ids = [f"id{i}" for i in range(len(texts))]
|
self.ids = [f"id{i}" for i in range(len(texts))]
|
||||||
self.collection.add(documents=texts, ids=self.ids)
|
self.collection.add(documents=texts, ids=self.ids)
|
||||||
|
|
||||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get_documents_and_ids(self, search_strings: list[str], n_results: int):
|
||||||
n_results = min(len(self.ids), n_results)
|
n_results = min(len(self.ids), n_results)
|
||||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
|
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])
|
||||||
return result
|
documents = result['documents'][0]
|
||||||
|
ids = list(map(lambda x: int(x[2:]), result['ids'][0]))
|
||||||
|
return documents, ids
|
||||||
|
|
||||||
|
# Get chunks by similarity
|
||||||
|
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
|
documents, _ = self.get_documents_and_ids(search_strings, n_results)
|
||||||
|
return documents
|
||||||
|
|
||||||
|
# Get ids by similarity
|
||||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
n_results = min(len(self.ids), n_results)
|
_ , ids = self.get_documents_and_ids(search_strings, n_results)
|
||||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
|
return ids
|
||||||
return list(map(lambda x: int(x[2:]), result))
|
|
||||||
|
# Get chunks by similarity and then sort by insertion order
|
||||||
|
def get_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
|
documents, ids = self.get_documents_and_ids(search_strings, n_results)
|
||||||
|
return [x for _, x in sorted(zip(ids, documents))]
|
||||||
|
|
||||||
|
# Get ids by similarity and then sort by insertion order
|
||||||
|
def get_ids_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
|
_ , ids = self.get_documents_and_ids(search_strings, n_results)
|
||||||
|
return sorted(ids)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.collection.delete(ids=self.ids)
|
self.collection.delete(ids=self.ids)
|
||||||
|
@ -96,7 +96,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
global chat_collector
|
global chat_collector
|
||||||
|
|
||||||
if state['mode'] == 'instruct':
|
if state['mode'] == 'instruct':
|
||||||
results = collector.get(user_input, n_results=chunk_count)
|
results = collector.get_sorted(user_input, n_results=chunk_count)
|
||||||
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
|
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
|
||||||
user_input += additional_context
|
user_input += additional_context
|
||||||
else:
|
else:
|
||||||
@ -116,7 +116,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
add_chunks_to_collector(chunks, chat_collector)
|
add_chunks_to_collector(chunks, chat_collector)
|
||||||
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
||||||
try:
|
try:
|
||||||
best_ids = chat_collector.get_ids(query, n_results=chunk_count)
|
best_ids = chat_collector.get_ids_sorted(query, n_results=chunk_count)
|
||||||
additional_context = '\n'
|
additional_context = '\n'
|
||||||
for id_ in best_ids:
|
for id_ in best_ids:
|
||||||
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||||
@ -147,7 +147,7 @@ def input_modifier(string):
|
|||||||
user_input = match.group(1).strip()
|
user_input = match.group(1).strip()
|
||||||
|
|
||||||
# Get the most similar chunks
|
# Get the most similar chunks
|
||||||
results = collector.get(user_input, n_results=chunk_count)
|
results = collector.get_sorted(user_input, n_results=chunk_count)
|
||||||
|
|
||||||
# Make the injection
|
# Make the injection
|
||||||
string = string.replace('<|injection-point|>', '\n'.join(results))
|
string = string.replace('<|injection-point|>', '\n'.join(results))
|
||||||
|
Loading…
Reference in New Issue
Block a user