mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add superbooga time weighted history retrieval (#2080)
This commit is contained in:
parent
a04266161d
commit
ee674afa50
@ -47,34 +47,58 @@ class ChromaCollector(Collecter):
|
|||||||
self.ids = [f"id{i}" for i in range(len(texts))]
|
self.ids = [f"id{i}" for i in range(len(texts))]
|
||||||
self.collection.add(documents=texts, ids=self.ids)
|
self.collection.add(documents=texts, ids=self.ids)
|
||||||
|
|
||||||
def get_documents_and_ids(self, search_strings: list[str], n_results: int):
|
def get_documents_ids_distances(self, search_strings: list[str], n_results: int):
|
||||||
n_results = min(len(self.ids), n_results)
|
n_results = min(len(self.ids), n_results)
|
||||||
if n_results == 0:
|
if n_results == 0:
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])
|
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents', 'distances'])
|
||||||
documents = result['documents'][0]
|
documents = result['documents'][0]
|
||||||
ids = list(map(lambda x: int(x[2:]), result['ids'][0]))
|
ids = list(map(lambda x: int(x[2:]), result['ids'][0]))
|
||||||
return documents, ids
|
distances = result['distances'][0]
|
||||||
|
return documents, ids, distances
|
||||||
|
|
||||||
# Get chunks by similarity
|
# Get chunks by similarity
|
||||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
documents, _ = self.get_documents_and_ids(search_strings, n_results)
|
documents, _, _ = self.get_documents_ids_distances(search_strings, n_results)
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
# Get ids by similarity
|
# Get ids by similarity
|
||||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
_, ids = self.get_documents_and_ids(search_strings, n_results)
|
_, ids, _ = self.get_documents_ids_distances(search_strings, n_results)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
# Get chunks by similarity and then sort by insertion order
|
# Get chunks by similarity and then sort by insertion order
|
||||||
def get_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
documents, ids = self.get_documents_and_ids(search_strings, n_results)
|
documents, ids, _ = self.get_documents_ids_distances(search_strings, n_results)
|
||||||
return [x for _, x in sorted(zip(ids, documents))]
|
return [x for _, x in sorted(zip(ids, documents))]
|
||||||
|
|
||||||
|
# Multiply distance by factor within [0, time_weight] where more recent is lower
|
||||||
|
def apply_time_weight_to_distances(self, ids: list[int], distances: list[float], time_weight: float = 1.0) -> list[float]:
|
||||||
|
if len(self.ids) <= 1:
|
||||||
|
return distances.copy()
|
||||||
|
|
||||||
|
return [distance * (1 - _id / (len(self.ids) - 1) * time_weight) for _id, distance in zip(ids, distances)]
|
||||||
|
|
||||||
# Get ids by similarity and then sort by insertion order
|
# Get ids by similarity and then sort by insertion order
|
||||||
def get_ids_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get_ids_sorted(self, search_strings: list[str], n_results: int, n_initial: int = None, time_weight: float = 1.0) -> list[str]:
|
||||||
_, ids = self.get_documents_and_ids(search_strings, n_results)
|
do_time_weight = time_weight > 0
|
||||||
|
if not (do_time_weight and n_initial is not None):
|
||||||
|
n_initial = n_results
|
||||||
|
elif n_initial == -1:
|
||||||
|
n_initial = len(self.ids)
|
||||||
|
|
||||||
|
if n_initial < n_results:
|
||||||
|
raise ValueError(f"n_initial {n_initial} should be >= n_results {n_results}")
|
||||||
|
|
||||||
|
_, ids, distances = self.get_documents_ids_distances(search_strings, n_initial)
|
||||||
|
if do_time_weight:
|
||||||
|
distances_w = self.apply_time_weight_to_distances(ids, distances, time_weight=time_weight)
|
||||||
|
results = zip(ids, distances, distances_w)
|
||||||
|
results = sorted(results, key=lambda x: x[2])[:n_results]
|
||||||
|
results = sorted(results, key=lambda x: x[0])
|
||||||
|
ids = [x[0] for x in results]
|
||||||
|
|
||||||
return sorted(ids)
|
return sorted(ids)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
|
@ -12,6 +12,8 @@ from .download_urls import download_urls
|
|||||||
|
|
||||||
params = {
|
params = {
|
||||||
'chunk_count': 5,
|
'chunk_count': 5,
|
||||||
|
'chunk_count_initial': 10,
|
||||||
|
'time_weight': 0,
|
||||||
'chunk_length': 700,
|
'chunk_length': 700,
|
||||||
'chunk_separator': '',
|
'chunk_separator': '',
|
||||||
'strong_cleanup': False,
|
'strong_cleanup': False,
|
||||||
@ -20,7 +22,6 @@ params = {
|
|||||||
|
|
||||||
collector = make_collector()
|
collector = make_collector()
|
||||||
chat_collector = make_collector()
|
chat_collector = make_collector()
|
||||||
chunk_count = 5
|
|
||||||
|
|
||||||
|
|
||||||
def feed_data_into_collector(corpus, chunk_len, chunk_sep):
|
def feed_data_into_collector(corpus, chunk_len, chunk_sep):
|
||||||
@ -83,13 +84,12 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
|
|||||||
yield i
|
yield i
|
||||||
|
|
||||||
|
|
||||||
def apply_settings(_chunk_count):
|
def apply_settings(chunk_count, chunk_count_initial, time_weight):
|
||||||
global chunk_count
|
global params
|
||||||
chunk_count = int(_chunk_count)
|
params['chunk_count'] = int(chunk_count)
|
||||||
settings_to_display = {
|
params['chunk_count_initial'] = int(chunk_count_initial)
|
||||||
'chunk_count': chunk_count,
|
params['time_weight'] = time_weight
|
||||||
}
|
settings_to_display = {k: params[k] for k in params if k in ['chunk_count', 'chunk_count_initial', 'time_weight']}
|
||||||
|
|
||||||
yield f"The following settings are now active: {str(settings_to_display)}"
|
yield f"The following settings are now active: {str(settings_to_display)}"
|
||||||
|
|
||||||
|
|
||||||
@ -97,7 +97,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
global chat_collector
|
global chat_collector
|
||||||
|
|
||||||
if state['mode'] == 'instruct':
|
if state['mode'] == 'instruct':
|
||||||
results = collector.get_sorted(user_input, n_results=chunk_count)
|
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
|
||||||
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
|
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
|
||||||
user_input += additional_context
|
user_input += additional_context
|
||||||
else:
|
else:
|
||||||
@ -108,7 +108,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
|
output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
|
||||||
return output
|
return output
|
||||||
|
|
||||||
if len(shared.history['internal']) > chunk_count and user_input != '':
|
if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
|
||||||
chunks = []
|
chunks = []
|
||||||
hist_size = len(shared.history['internal'])
|
hist_size = len(shared.history['internal'])
|
||||||
for i in range(hist_size-1):
|
for i in range(hist_size-1):
|
||||||
@ -117,7 +117,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
add_chunks_to_collector(chunks, chat_collector)
|
add_chunks_to_collector(chunks, chat_collector)
|
||||||
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
||||||
try:
|
try:
|
||||||
best_ids = chat_collector.get_ids_sorted(query, n_results=chunk_count)
|
best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
|
||||||
additional_context = '\n'
|
additional_context = '\n'
|
||||||
for id_ in best_ids:
|
for id_ in best_ids:
|
||||||
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||||
@ -151,7 +151,7 @@ def input_modifier(string):
|
|||||||
user_input = match.group(1).strip()
|
user_input = match.group(1).strip()
|
||||||
|
|
||||||
# Get the most similar chunks
|
# Get the most similar chunks
|
||||||
results = collector.get_sorted(user_input, n_results=chunk_count)
|
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
|
||||||
|
|
||||||
# Make the injection
|
# Make the injection
|
||||||
string = string.replace('<|injection-point|>', '\n'.join(results))
|
string = string.replace('<|injection-point|>', '\n'.join(results))
|
||||||
@ -240,6 +240,10 @@ def ui():
|
|||||||
|
|
||||||
with gr.Tab("Generation settings"):
|
with gr.Tab("Generation settings"):
|
||||||
chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
|
chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
|
||||||
|
gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)')
|
||||||
|
time_weight = gr.Slider(0, 1, value=params['time_weight'], label='Time weight', info='Defines the strength of the time weighting. 0 = no time weighting.')
|
||||||
|
chunk_count_initial = gr.Number(value=params['chunk_count_initial'], label='Initial chunk count', info='The number of closest-matching chunks retrieved for time weight reordering in chat mode. This should be >= chunk count. -1 = All chunks are retrieved. Only used if time_weight > 0.')
|
||||||
|
|
||||||
update_settings = gr.Button('Apply changes')
|
update_settings = gr.Button('Apply changes')
|
||||||
|
|
||||||
chunk_len = gr.Number(value=params['chunk_length'], label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
|
chunk_len = gr.Number(value=params['chunk_length'], label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
|
||||||
@ -250,4 +254,4 @@ def ui():
|
|||||||
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
|
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
|
||||||
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
|
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
|
||||||
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
|
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
|
||||||
update_settings.click(apply_settings, [chunk_count], last_updated, show_progress=False)
|
update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user