mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Add superbooga time weighted history retrieval (#2080)
This commit is contained in:
parent
a04266161d
commit
ee674afa50
@ -47,34 +47,58 @@ class ChromaCollector(Collecter):
|
||||
self.ids = [f"id{i}" for i in range(len(texts))]
|
||||
self.collection.add(documents=texts, ids=self.ids)
|
||||
|
||||
def get_documents_and_ids(self, search_strings: list[str], n_results: int):
|
||||
def get_documents_ids_distances(self, search_strings: list[str], n_results: int):
|
||||
n_results = min(len(self.ids), n_results)
|
||||
if n_results == 0:
|
||||
return [], []
|
||||
|
||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])
|
||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents', 'distances'])
|
||||
documents = result['documents'][0]
|
||||
ids = list(map(lambda x: int(x[2:]), result['ids'][0]))
|
||||
return documents, ids
|
||||
distances = result['distances'][0]
|
||||
return documents, ids, distances
|
||||
|
||||
# Get chunks by similarity
|
||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
documents, _ = self.get_documents_and_ids(search_strings, n_results)
|
||||
documents, _, _ = self.get_documents_ids_distances(search_strings, n_results)
|
||||
return documents
|
||||
|
||||
# Get ids by similarity
|
||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
_, ids = self.get_documents_and_ids(search_strings, n_results)
|
||||
_, ids, _ = self.get_documents_ids_distances(search_strings, n_results)
|
||||
return ids
|
||||
|
||||
# Get chunks by similarity and then sort by insertion order
|
||||
def get_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
documents, ids = self.get_documents_and_ids(search_strings, n_results)
|
||||
documents, ids, _ = self.get_documents_ids_distances(search_strings, n_results)
|
||||
return [x for _, x in sorted(zip(ids, documents))]
|
||||
|
||||
# Multiply distance by factor within [0, time_weight] where more recent is lower
|
||||
def apply_time_weight_to_distances(self, ids: list[int], distances: list[float], time_weight: float = 1.0) -> list[float]:
|
||||
if len(self.ids) <= 1:
|
||||
return distances.copy()
|
||||
|
||||
return [distance * (1 - _id / (len(self.ids) - 1) * time_weight) for _id, distance in zip(ids, distances)]
|
||||
|
||||
# Get ids by similarity and then sort by insertion order
|
||||
def get_ids_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
_, ids = self.get_documents_and_ids(search_strings, n_results)
|
||||
def get_ids_sorted(self, search_strings: list[str], n_results: int, n_initial: int = None, time_weight: float = 1.0) -> list[str]:
|
||||
do_time_weight = time_weight > 0
|
||||
if not (do_time_weight and n_initial is not None):
|
||||
n_initial = n_results
|
||||
elif n_initial == -1:
|
||||
n_initial = len(self.ids)
|
||||
|
||||
if n_initial < n_results:
|
||||
raise ValueError(f"n_initial {n_initial} should be >= n_results {n_results}")
|
||||
|
||||
_, ids, distances = self.get_documents_ids_distances(search_strings, n_initial)
|
||||
if do_time_weight:
|
||||
distances_w = self.apply_time_weight_to_distances(ids, distances, time_weight=time_weight)
|
||||
results = zip(ids, distances, distances_w)
|
||||
results = sorted(results, key=lambda x: x[2])[:n_results]
|
||||
results = sorted(results, key=lambda x: x[0])
|
||||
ids = [x[0] for x in results]
|
||||
|
||||
return sorted(ids)
|
||||
|
||||
def clear(self):
|
||||
|
@ -12,6 +12,8 @@ from .download_urls import download_urls
|
||||
|
||||
params = {
|
||||
'chunk_count': 5,
|
||||
'chunk_count_initial': 10,
|
||||
'time_weight': 0,
|
||||
'chunk_length': 700,
|
||||
'chunk_separator': '',
|
||||
'strong_cleanup': False,
|
||||
@ -20,7 +22,6 @@ params = {
|
||||
|
||||
collector = make_collector()
|
||||
chat_collector = make_collector()
|
||||
chunk_count = 5
|
||||
|
||||
|
||||
def feed_data_into_collector(corpus, chunk_len, chunk_sep):
|
||||
@ -83,13 +84,12 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
|
||||
yield i
|
||||
|
||||
|
||||
def apply_settings(_chunk_count):
|
||||
global chunk_count
|
||||
chunk_count = int(_chunk_count)
|
||||
settings_to_display = {
|
||||
'chunk_count': chunk_count,
|
||||
}
|
||||
|
||||
def apply_settings(chunk_count, chunk_count_initial, time_weight):
|
||||
global params
|
||||
params['chunk_count'] = int(chunk_count)
|
||||
params['chunk_count_initial'] = int(chunk_count_initial)
|
||||
params['time_weight'] = time_weight
|
||||
settings_to_display = {k: params[k] for k in params if k in ['chunk_count', 'chunk_count_initial', 'time_weight']}
|
||||
yield f"The following settings are now active: {str(settings_to_display)}"
|
||||
|
||||
|
||||
@ -97,7 +97,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
global chat_collector
|
||||
|
||||
if state['mode'] == 'instruct':
|
||||
results = collector.get_sorted(user_input, n_results=chunk_count)
|
||||
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
|
||||
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
|
||||
user_input += additional_context
|
||||
else:
|
||||
@ -108,7 +108,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
|
||||
return output
|
||||
|
||||
if len(shared.history['internal']) > chunk_count and user_input != '':
|
||||
if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
|
||||
chunks = []
|
||||
hist_size = len(shared.history['internal'])
|
||||
for i in range(hist_size-1):
|
||||
@ -117,7 +117,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
add_chunks_to_collector(chunks, chat_collector)
|
||||
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
||||
try:
|
||||
best_ids = chat_collector.get_ids_sorted(query, n_results=chunk_count)
|
||||
best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
|
||||
additional_context = '\n'
|
||||
for id_ in best_ids:
|
||||
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
@ -151,7 +151,7 @@ def input_modifier(string):
|
||||
user_input = match.group(1).strip()
|
||||
|
||||
# Get the most similar chunks
|
||||
results = collector.get_sorted(user_input, n_results=chunk_count)
|
||||
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
|
||||
|
||||
# Make the injection
|
||||
string = string.replace('<|injection-point|>', '\n'.join(results))
|
||||
@ -240,6 +240,10 @@ def ui():
|
||||
|
||||
with gr.Tab("Generation settings"):
|
||||
chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
|
||||
gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)')
|
||||
time_weight = gr.Slider(0, 1, value=params['time_weight'], label='Time weight', info='Defines the strength of the time weighting. 0 = no time weighting.')
|
||||
chunk_count_initial = gr.Number(value=params['chunk_count_initial'], label='Initial chunk count', info='The number of closest-matching chunks retrieved for time weight reordering in chat mode. This should be >= chunk count. -1 = All chunks are retrieved. Only used if time_weight > 0.')
|
||||
|
||||
update_settings = gr.Button('Apply changes')
|
||||
|
||||
chunk_len = gr.Number(value=params['chunk_length'], label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
|
||||
@ -250,4 +254,4 @@ def ui():
|
||||
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
|
||||
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
|
||||
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
|
||||
update_settings.click(apply_settings, [chunk_count], last_updated, show_progress=False)
|
||||
update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
|
||||
|
Loading…
Reference in New Issue
Block a user