mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-01 07:00:15 +01:00
376 lines
15 KiB
Python
376 lines
15 KiB
Python
|
import threading
|
||
|
import chromadb
|
||
|
import posthog
|
||
|
import torch
|
||
|
import math
|
||
|
|
||
|
import numpy as np
|
||
|
import extensions.superboogav2.parameters as parameters
|
||
|
|
||
|
from chromadb.config import Settings
|
||
|
from sentence_transformers import SentenceTransformer
|
||
|
|
||
|
from modules.logging_colors import logger
|
||
|
from modules.text_generation import encode, decode
|
||
|
|
||
|
# Swap posthog.capture for a do-nothing callable so analytics events are dropped
# (presumably the telemetry chromadb would otherwise emit — confirm).
logger.debug('Intercepting all calls to posthog.')
posthog.capture = (lambda *_args, **_kwargs: None)
|
||
|
|
||
|
|
||
|
class Collecter:
    """Interface for a store of text chunks that can be queried by similarity.

    Every method of this base class is a no-op that returns None; concrete
    collectors (e.g. ChromaCollector) override them.
    """

    def __init__(self):
        """Base collector keeps no state."""

    def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int]):
        """Subclasses store `texts` together with their context-expanded text and start offsets."""

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        """Subclasses return up to `n_results` stored chunks matching `search_strings`."""

    def clear(self):
        """Subclasses drop everything stored so far."""
|
||
|
|
||
|
|
||
|
class Embedder:
    """Interface for objects that turn text into embedding vectors.

    The base implementation does nothing and returns None; subclasses provide
    a working `embed`.
    """

    def __init__(self):
        """Base embedder keeps no state."""

    def embed(self, text: str) -> list[torch.Tensor]:
        """Return the embedding for `text` (None in this base class)."""
|
||
|
|
||
|
class Info:
    """One retrieved chunk: its context-expanded text, start offset, distance and id.

    `start_index` and `len(text_with_context)` are combined to decide whether two
    chunks overlap or touch, so both are assumed to use the same unit (presumably
    characters of the source document — confirm against the chunking caller).
    """

    def __init__(self, start_index, text_with_context, distance, id):
        self.text_with_context = text_with_context
        self.start_index = start_index
        self.distance = distance
        self.id = id  # id of the record this chunk came from

    def calculate_distance(self, other_info):
        """Combine this chunk's distance with `other_info`'s using the configured strategy.

        The strategy is read once from `parameters.get_new_dist_strategy()`:
        min, harmonic mean, geometric mean or arithmetic mean. Any unrecognised
        strategy falls back to the minimum.

        Returns:
            The combined distance. For the harmonic mean, two zero distances
            (exact matches) yield 0.0 instead of raising ZeroDivisionError.
        """
        strategy = parameters.get_new_dist_strategy()
        d1, d2 = self.distance, other_info.distance

        if strategy == parameters.DIST_MIN_STRATEGY:
            return min(d1, d2)
        if strategy == parameters.DIST_HARMONIC_STRATEGY:
            total = d1 + d2
            if total == 0:
                # Limit of 2ab/(a+b) as a, b -> 0; avoids 0/0.
                return 0.0
            return 2 * (d1 * d2) / total
        if strategy == parameters.DIST_GEOMETRIC_STRATEGY:
            return (d1 * d2) ** 0.5
        if strategy == parameters.DIST_ARITHMETIC_STRATEGY:
            return (d1 + d2) / 2
        # Min is the default.
        return min(d1, d2)

    def merge_with(self, other_info):
        """Merge with `other_info` when the two chunks overlap or are adjacent.

        Returns:
            None if the chunks are disjoint. Otherwise a new Info that:
            - spans both texts without duplicating the overlapping part;
            - keeps the start index and id of the chunk that starts first
              (this one on a tie);
            - has a distance combined by `calculate_distance`.
        """
        s1, s1_start = self.text_with_context, self.start_index
        s2, s2_start = other_info.text_with_context, other_info.start_index

        # Only combine distances when a merge actually happens.
        if not self.should_merge(s1, s2, s1_start, s2_start):
            return None

        new_dist = self.calculate_distance(other_info)

        # Orient the pair so `first` starts no later than `second`.
        if s1_start <= s2_start:
            first, first_start, first_id = s1, s1_start, self.id
            second, second_start = s2, s2_start
        else:
            first, first_start, first_id = s2, s2_start, other_info.id
            second, second_start = s1, s1_start

        first_end = first_start + len(first)
        if first_end >= second_start + len(second):
            # `first` completely covers `second`.
            return Info(first_start, first, new_dist, first_id)

        # Append only the tail of `second` that extends past `first`.
        overlap = max(0, first_end - second_start)
        return Info(first_start, first + second[overlap:], new_dist, first_id)

    @staticmethod
    def should_merge(s1, s2, s1_start, s2_start):
        """Return True if the spans of s1 and s2 overlap or touch end-to-start."""
        s1_end = s1_start + len(s1)
        s2_end = s2_start + len(s2)

        return not (s1_end < s2_start or s2_end < s1_start)
|
||
|
|
||
|
class ChromaCollector(Collecter):
    """Collecter that stores chunks in a chromadb collection named "context".

    Alongside the chroma collection it keeps:
      * `ids` - record ids added and not yet deleted (decimal strings made by `_get_new_ids`);
      * `id_to_info` - id -> {'text_with_context': ..., 'start_index': ...};
      * `embeddings_cache` - text -> embedding list, reused by later `add` calls
        (not emptied by `clear`);
      * `lock` - held by every public method so concurrent callers are serialised.
    """

    def __init__(self, embedder: Embedder):
        super().__init__()
        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.embedder = embedder
        self.collection = self.chroma_client.create_collection(name="context", embedding_function=self.embedder.embed)
        self.ids = []
        self.id_to_info = {}
        self.embeddings_cache = {}
        self.lock = threading.Lock()  # Locking so the server doesn't break.

    def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None):
        """Embed and store `texts`, reusing cached embeddings where possible.

        Args:
            texts: chunk texts; these are embedded and stored as chroma documents.
            texts_with_context: per chunk, the text returned on retrieval.
            starting_indices: per chunk, the start offset of its context text.
            metadatas: optional per-chunk metadata; must match `texts` in length.
        """
        with self.lock:
            assert metadatas is None or len(metadatas) == len(texts), "metadatas must be None or have the same length as texts"

            if len(texts) == 0:
                return

            new_ids = self._get_new_ids(len(texts))

            (existing_texts, existing_embeddings, existing_ids, existing_metas), \
                (non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas)

            # If there are any already existing texts, add them all at once.
            if existing_texts:
                logger.info(f'Adding {len(existing_embeddings)} cached embeddings.')
                args = {'embeddings': existing_embeddings, 'documents': existing_texts, 'ids': existing_ids}
                if metadatas is not None:
                    args['metadatas'] = existing_metas
                self.collection.add(**args)

            # If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
            if non_existing_texts:
                non_existing_embeddings = self.embedder.embed(non_existing_texts).tolist()
                for text, embedding in zip(non_existing_texts, non_existing_embeddings):
                    self.embeddings_cache[text] = embedding

                logger.info(f'Adding {len(non_existing_embeddings)} new embeddings.')
                args = {'embeddings': non_existing_embeddings, 'documents': non_existing_texts, 'ids': non_existing_ids}
                if metadatas is not None:
                    args['metadatas'] = non_existing_metas
                self.collection.add(**args)

            # Create a dictionary that maps each ID to its context and starting index
            new_info = {
                id_: {'text_with_context': context, 'start_index': start_index}
                for id_, context, start_index in zip(new_ids, texts_with_context, starting_indices)
            }

            self.id_to_info.update(new_info)
            self.ids.extend(new_ids)

    def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]):
        """Partition texts into those with a cached embedding and those without.

        Returns:
            ((existing_texts, existing_embeddings, existing_ids, existing_metas),
             (non_existing_texts, non_existing_ids, non_existing_metas)), each list
            in the original order; metas are None entries when `metadatas` is None.
        """
        existing_texts, non_existing_texts = [], []
        existing_embeddings = []
        existing_ids, non_existing_ids = [], []
        existing_metas, non_existing_metas = [], []

        for i, text in enumerate(texts):
            id_ = new_ids[i]
            metadata = metadatas[i] if metadatas is not None else None
            embedding = self.embeddings_cache.get(text)
            # Explicit None check: a cache hit must not depend on the embedding's truthiness.
            if embedding is not None:
                existing_texts.append(text)
                existing_embeddings.append(embedding)
                existing_ids.append(id_)
                existing_metas.append(metadata)
            else:
                non_existing_texts.append(text)
                non_existing_ids.append(id_)
                non_existing_metas.append(metadata)

        return (existing_texts, existing_embeddings, existing_ids, existing_metas), \
            (non_existing_texts, non_existing_ids, non_existing_metas)

    def _get_new_ids(self, num_new_ids: int):
        """Return `num_new_ids` consecutive decimal-string ids after the largest current id."""
        if self.ids:
            max_existing_id = max(int(id_) for id_ in self.ids)
        else:
            max_existing_id = -1

        return [str(i + max_existing_id + 1) for i in range(num_new_ids)]

    def _find_min_max_start_index(self):
        """Return (min, max) start_index over id_to_info; (inf, 0) when it is empty."""
        max_index, min_index = 0, float('inf')
        for val in self.id_to_info.values():
            if val['start_index'] > max_index:
                max_index = val['start_index']
            if val['start_index'] < min_index:
                min_index = val['start_index']
        return min_index, max_index

    # NB: Does not make sense to weigh excerpts from different documents.
    # But let's say that's the user's problem. Perfect world scenario:
    # Apply time weighing to different documents. For each document, then, add
    # separate time weighing.
    def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float, min_start_index: int = 0):
        """Multiply each info's distance in place by a position-dependent weight.

        A sigmoid over `document_len` positions is rescaled to [1 - time_power, 1]
        and reversed. For a positive `time_steepness`, position 0 keeps weight 1
        and the last position gets 1 - time_power, so later chunks look closer.

        Args:
            min_start_index: start index that maps to weight position 0, so that
                `document_len` may cover only the [min, max] start-index range.
        """
        sigmoid = lambda x: 1 / (1 + np.exp(-x))

        weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))

        # Scale to [0,time_power] and shift it up to [1-time_power, 1]
        weights = weights - weights.min()
        span = weights.max()
        if span > 0:
            weights = weights * (time_power / span)
            weights = weights + (1 - time_power)
        else:
            # Flat curve (single position or zero steepness): dividing by a zero
            # span would make every distance NaN, so leave distances unchanged.
            weights = np.ones_like(weights)

        # Reverse the weights
        weights = weights[::-1]

        for info in infos:
            info.distance *= weights[info.start_index - min_start_index]

    def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float):
        """Keep infos whose distance is <= significant_level * median distance.

        The info with the smallest distance is always kept.
        """
        # Ensure there are infos to filter
        if not infos:
            return []

        # Find info with minimum distance
        min_info = min(infos, key=lambda x: x.distance)

        # Calculate median distance among infos
        median_distance = np.median([inf.distance for inf in infos])

        # Filter out infos that have a distance significantly greater than the median
        filtered_infos = [inf for inf in infos if inf.distance <= significant_level * median_distance]

        # Always include the info with minimum distance
        if min_info not in filtered_infos:
            filtered_infos.append(min_info)

        return filtered_infos

    def _merge_infos(self, infos: list[Info]):
        """Merge consecutive overlapping/adjacent infos; expects infos sorted by start_index."""
        if not infos:
            return []

        merged_infos = []
        current_info = infos[0]

        for next_info in infos[1:]:
            merged = current_info.merge_with(next_info)
            if merged is not None:
                current_info = merged
            else:
                merged_infos.append(current_info)
                current_info = next_info

        merged_infos.append(current_info)
        return merged_infos

    # Main function for retrieving chunks by distance. It performs merging, time weighing, and mean filtering.
    def _get_documents_ids_distances(self, search_strings: list[str], n_results: int):
        """Query chroma for each search string, then weigh, filter and merge the hits.

        Returns:
            Three parallel lists (texts_with_context, ids, distances), ordered by
            start index; all empty when nothing is stored or nothing matches.
        """
        n_results = min(len(self.ids), n_results)
        if n_results == 0:
            return [], [], []

        if isinstance(search_strings, str):
            search_strings = [search_strings]

        infos = []
        min_start_index, max_start_index = self._find_min_max_start_index()

        for search_string in search_strings:
            result = self.collection.query(query_texts=search_string, n_results=math.ceil(n_results / len(search_strings)), include=['distances'])
            curr_infos = [Info(start_index=self.id_to_info[id_]['start_index'],
                               text_with_context=self.id_to_info[id_]['text_with_context'],
                               distance=distance, id=id_)
                          for id_, distance in zip(result['ids'][0], result['distances'][0])]

            # Weights cover only [min_start_index, max_start_index]; offset the lookup accordingly.
            self._apply_sigmoid_time_weighing(infos=curr_infos, document_len=max_start_index - min_start_index + 1, time_steepness=parameters.get_time_steepness(), time_power=parameters.get_time_power(), min_start_index=min_start_index)
            curr_infos = self._filter_outliers_by_median_distance(curr_infos, parameters.get_significant_level())
            infos.extend(curr_infos)

        infos.sort(key=lambda x: x.start_index)
        infos = self._merge_infos(infos)

        texts_with_context = [inf.text_with_context for inf in infos]
        ids = [inf.id for inf in infos]
        distances = [inf.distance for inf in infos]

        return texts_with_context, ids, distances

    # Get chunks by similarity
    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        """Return the merged context texts most similar to `search_strings`."""
        with self.lock:
            documents, _, _ = self._get_documents_ids_distances(search_strings, n_results)
            return documents

    # Get ids by similarity
    def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
        """Return the ids of the merged chunks most similar to `search_strings`."""
        with self.lock:
            _, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
            return ids

    # Cutoff token count
    def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int):
        """Return a prefix of `documents` whose total token count fits in `max_token_count`.

        The first document that does not fit is truncated to the remaining budget,
        or dropped entirely when no budget remains; later documents are discarded.
        """
        # TODO: Move to caller; We add delimiters there which might go over the limit.
        current_token_count = 0
        return_documents = []

        for doc in documents:
            doc_tokens = encode(doc)[0]
            doc_token_count = len(doc_tokens)
            if current_token_count + doc_token_count > max_token_count:
                # If adding this document would exceed the max token count,
                # truncate the document to fit within the limit.
                remaining_tokens = max_token_count - current_token_count
                # With no budget left, slicing would yield an empty (or, for a
                # negative budget, wrong) prefix, so skip the document instead.
                if remaining_tokens > 0:
                    truncated_doc = decode(doc_tokens[:remaining_tokens], skip_special_tokens=True)
                    return_documents.append(truncated_doc)
                break

            return_documents.append(doc)
            current_token_count += doc_token_count

        return return_documents

    # Get chunks by similarity and then sort by ids
    def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
        """Return matching chunks in id (insertion) order, cut to `max_token_count` tokens."""
        with self.lock:
            documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
            # Ids are decimal strings; compare numerically so "10" sorts after "9".
            ordered = sorted(zip(ids, documents), key=lambda pair: int(pair[0]))
            sorted_docs = [doc for _, doc in ordered]

            return self._get_documents_up_to_token_count(sorted_docs, max_token_count)

    # Get chunks by similarity and then sort by distance (lowest distance is last).
    def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
        """Return matching chunks ordered by descending distance, cut to `max_token_count` tokens."""
        with self.lock:
            documents, _, distances = self._get_documents_ids_distances(search_strings, n_results)
            sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])]  # sorted lowest -> highest

            # If a document is truncated or completely skipped, it would be with high distance.
            return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count)
            return_documents.reverse()  # highest -> lowest

            return return_documents

    def delete(self, ids_to_delete: list[str], where: dict):
        """Delete matching records from chroma and from the local id bookkeeping."""
        with self.lock:
            ids_to_delete = self.collection.get(ids=ids_to_delete, where=where)['ids']
            self.collection.delete(ids=ids_to_delete, where=where)

            # Remove the deleted ids from self.ids and self.id_to_info
            ids_set = set(ids_to_delete)
            self.ids = [id_ for id_ in self.ids if id_ not in ids_set]
            for id_ in ids_to_delete:
                self.id_to_info.pop(id_, None)

            logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.')

    def clear(self):
        """Reset the chroma client, recreate the "context" collection and forget all ids."""
        with self.lock:
            self.chroma_client.reset()
            self.collection = self.chroma_client.create_collection("context", embedding_function=self.embedder.embed)
            self.ids = []
            self.id_to_info = {}

            logger.info('Successfully cleared all records and reset chromaDB.')
|
||
|
|
||
|
|
||
|
class SentenceTransformerEmbedder(Embedder):
    """Embedder whose `embed` is a sentence-transformers model's `encode` method."""

    def __init__(self, model_name: str = "sentence-transformers/all-mpnet-base-v2") -> None:
        """Load the sentence-transformers model.

        Args:
            model_name: identifier passed to `SentenceTransformer`. The default is
                the model this extension previously hard-coded, so existing
                callers are unaffected.
        """
        super().__init__()
        logger.debug('Creating Sentence Embedder...')
        self.model = SentenceTransformer(model_name)
        # Bind encode directly as embed: ChromaCollector calls embed() with a list
        # of texts and calls .tolist() on the result.
        self.embed = self.model.encode
|
||
|
|
||
|
|
||
|
def make_collector():
    """Build the default collector: a ChromaCollector using sentence-transformer embeddings."""
    embedder = SentenceTransformerEmbedder()
    collector = ChromaCollector(embedder)
    return collector
|