mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-29 21:50:16 +01:00
Make superbooga & superboogav2 functional again (#5656)
This commit is contained in:
parent
bae14c8f13
commit
2681f6f640
@ -1,43 +1,24 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
import posthog
|
import posthog
|
||||||
import torch
|
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from sentence_transformers import SentenceTransformer
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
from modules.logging_colors import logger
|
# Intercept calls to posthog
|
||||||
|
|
||||||
logger.info('Intercepting all calls to posthog :)')
|
|
||||||
posthog.capture = lambda *args, **kwargs: None
|
posthog.capture = lambda *args, **kwargs: None
|
||||||
|
|
||||||
|
|
||||||
class Collecter():
|
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaCollector():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
name = ''.join(random.choice('ab') for _ in range(10))
|
||||||
|
|
||||||
def add(self, texts: list[str]):
|
self.name = name
|
||||||
pass
|
|
||||||
|
|
||||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Embedder():
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def embed(self, text: str) -> list[torch.Tensor]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ChromaCollector(Collecter):
|
|
||||||
def __init__(self, embedder: Embedder):
|
|
||||||
super().__init__()
|
|
||||||
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
||||||
self.embedder = embedder
|
self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
|
||||||
self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
|
|
||||||
self.ids = []
|
self.ids = []
|
||||||
|
|
||||||
def add(self, texts: list[str]):
|
def add(self, texts: list[str]):
|
||||||
@ -102,24 +83,15 @@ class ChromaCollector(Collecter):
|
|||||||
return sorted(ids)
|
return sorted(ids)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.collection.delete(ids=self.ids)
|
|
||||||
self.ids = []
|
self.ids = []
|
||||||
|
self.chroma_client.delete_collection(name=self.name)
|
||||||
|
self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
|
||||||
class SentenceTransformerEmbedder(Embedder):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
|
|
||||||
self.embed = self.model.encode
|
|
||||||
|
|
||||||
|
|
||||||
def make_collector():
|
def make_collector():
|
||||||
global embedder
|
return ChromaCollector()
|
||||||
return ChromaCollector(embedder)
|
|
||||||
|
|
||||||
|
|
||||||
def add_chunks_to_collector(chunks, collector):
|
def add_chunks_to_collector(chunks, collector):
|
||||||
collector.clear()
|
collector.clear()
|
||||||
collector.add(chunks)
|
collector.add(chunks)
|
||||||
|
|
||||||
|
|
||||||
embedder = SentenceTransformerEmbedder()
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
beautifulsoup4==4.12.2
|
beautifulsoup4==4.12.2
|
||||||
chromadb==0.3.18
|
chromadb==0.4.24
|
||||||
pandas==2.0.3
|
pandas==2.0.3
|
||||||
posthog==2.4.2
|
posthog==2.4.2
|
||||||
sentence_transformers==2.2.2
|
sentence_transformers==2.2.2
|
||||||
|
@ -12,17 +12,16 @@ This module is responsible for the VectorDB API. It currently supports:
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from urllib.parse import urlparse, parse_qs
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
import extensions.superboogav2.parameters as parameters
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
from .chromadb import ChromaCollector
|
from .chromadb import ChromaCollector
|
||||||
from .data_processor import process_and_add_to_collector
|
from .data_processor import process_and_add_to_collector
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
|
||||||
|
|
||||||
|
|
||||||
class CustomThreadingHTTPServer(ThreadingHTTPServer):
|
class CustomThreadingHTTPServer(ThreadingHTTPServer):
|
||||||
def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
|
def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
|
||||||
@ -38,7 +37,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
self.collector = collector
|
self.collector = collector
|
||||||
super().__init__(request, client_address, server)
|
super().__init__(request, client_address, server)
|
||||||
|
|
||||||
|
|
||||||
def _send_412_error(self, message):
|
def _send_412_error(self, message):
|
||||||
self.send_response(412)
|
self.send_response(412)
|
||||||
self.send_header("Content-type", "application/json")
|
self.send_header("Content-type", "application/json")
|
||||||
@ -46,7 +44,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
response = json.dumps({"error": message})
|
response = json.dumps({"error": message})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def _send_404_error(self):
|
def _send_404_error(self):
|
||||||
self.send_response(404)
|
self.send_response(404)
|
||||||
self.send_header("Content-type", "application/json")
|
self.send_header("Content-type", "application/json")
|
||||||
@ -54,14 +51,12 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
response = json.dumps({"error": "Resource not found"})
|
response = json.dumps({"error": "Resource not found"})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def _send_400_error(self, error_message: str):
|
def _send_400_error(self, error_message: str):
|
||||||
self.send_response(400)
|
self.send_response(400)
|
||||||
self.send_header("Content-type", "application/json")
|
self.send_header("Content-type", "application/json")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
response = json.dumps({"error": error_message})
|
response = json.dumps({"error": error_message})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def _send_200_response(self, message: str):
|
def _send_200_response(self, message: str):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
@ -75,24 +70,21 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
|
def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
|
||||||
if sort_param == parameters.SORT_DISTANCE:
|
if sort_param == parameters.SORT_DISTANCE:
|
||||||
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
||||||
elif sort_param == parameters.SORT_ID:
|
elif sort_param == parameters.SORT_ID:
|
||||||
results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count)
|
results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count)
|
||||||
else: # Default is dist
|
else: # Default is dist
|
||||||
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"results": results
|
"results": results
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
self._send_404_error()
|
self._send_404_error()
|
||||||
|
|
||||||
|
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
try:
|
try:
|
||||||
content_length = int(self.headers['Content-Length'])
|
content_length = int(self.headers['Content-Length'])
|
||||||
@ -107,7 +99,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if corpus is None:
|
if corpus is None:
|
||||||
self._send_412_error("Missing parameter 'corpus'")
|
self._send_412_error("Missing parameter 'corpus'")
|
||||||
return
|
return
|
||||||
|
|
||||||
clear_before_adding = body.get('clear_before_adding', False)
|
clear_before_adding = body.get('clear_before_adding', False)
|
||||||
metadata = body.get('metadata')
|
metadata = body.get('metadata')
|
||||||
process_and_add_to_collector(corpus, self.collector, clear_before_adding, metadata)
|
process_and_add_to_collector(corpus, self.collector, clear_before_adding, metadata)
|
||||||
@ -118,7 +110,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if corpus is None:
|
if corpus is None:
|
||||||
self._send_412_error("Missing parameter 'metadata'")
|
self._send_412_error("Missing parameter 'metadata'")
|
||||||
return
|
return
|
||||||
|
|
||||||
self.collector.delete(ids_to_delete=None, where=metadata)
|
self.collector.delete(ids_to_delete=None, where=metadata)
|
||||||
self._send_200_response("Data successfully deleted")
|
self._send_200_response("Data successfully deleted")
|
||||||
|
|
||||||
@ -127,15 +119,15 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if search_strings is None:
|
if search_strings is None:
|
||||||
self._send_412_error("Missing parameter 'search_strings'")
|
self._send_412_error("Missing parameter 'search_strings'")
|
||||||
return
|
return
|
||||||
|
|
||||||
n_results = body.get('n_results')
|
n_results = body.get('n_results')
|
||||||
if n_results is None:
|
if n_results is None:
|
||||||
n_results = parameters.get_chunk_count()
|
n_results = parameters.get_chunk_count()
|
||||||
|
|
||||||
max_token_count = body.get('max_token_count')
|
max_token_count = body.get('max_token_count')
|
||||||
if max_token_count is None:
|
if max_token_count is None:
|
||||||
max_token_count = parameters.get_max_token_count()
|
max_token_count = parameters.get_max_token_count()
|
||||||
|
|
||||||
sort_param = query_params.get('sort', ['distance'])[0]
|
sort_param = query_params.get('sort', ['distance'])[0]
|
||||||
|
|
||||||
results = self._handle_get(search_strings, n_results, max_token_count, sort_param)
|
results = self._handle_get(search_strings, n_results, max_token_count, sort_param)
|
||||||
@ -146,7 +138,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._send_400_error(str(e))
|
self._send_400_error(str(e))
|
||||||
|
|
||||||
|
|
||||||
def do_DELETE(self):
|
def do_DELETE(self):
|
||||||
try:
|
try:
|
||||||
parsed_path = urlparse(self.path)
|
parsed_path = urlparse(self.path)
|
||||||
@ -161,12 +152,10 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._send_400_error(str(e))
|
self._send_400_error(str(e))
|
||||||
|
|
||||||
|
|
||||||
def do_OPTIONS(self):
|
def do_OPTIONS(self):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
|
|
||||||
def end_headers(self):
|
def end_headers(self):
|
||||||
self.send_header('Access-Control-Allow-Origin', '*')
|
self.send_header('Access-Control-Allow-Origin', '*')
|
||||||
self.send_header('Access-Control-Allow-Methods', '*')
|
self.send_header('Access-Control-Allow-Methods', '*')
|
||||||
@ -197,11 +186,11 @@ class APIManager:
|
|||||||
|
|
||||||
def stop_server(self):
|
def stop_server(self):
|
||||||
if self.server is not None:
|
if self.server is not None:
|
||||||
logger.info(f'Stopping chromaDB API.')
|
logger.info('Stopping chromaDB API.')
|
||||||
self.server.shutdown()
|
self.server.shutdown()
|
||||||
self.server.server_close()
|
self.server.server_close()
|
||||||
self.server = None
|
self.server = None
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
def is_server_running(self):
|
def is_server_running(self):
|
||||||
return self.is_running
|
return self.is_running
|
||||||
|
@ -9,23 +9,23 @@ The benchmark function will return the score as an integer.
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .data_processor import process_and_add_to_collector, preprocess_text
|
from .data_processor import preprocess_text, process_and_add_to_collector
|
||||||
from .parameters import get_chunk_count, get_max_token_count
|
from .parameters import get_chunk_count, get_max_token_count
|
||||||
from .utils import create_metadata_source
|
from .utils import create_metadata_source
|
||||||
|
|
||||||
|
|
||||||
def benchmark(config_path, collector):
|
def benchmark(config_path, collector):
|
||||||
# Get the current system date
|
# Get the current system date
|
||||||
sysdate = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
sysdate = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
filename = f"benchmark_{sysdate}.txt"
|
filename = f"benchmark_{sysdate}.txt"
|
||||||
|
|
||||||
# Open the log file in append mode
|
# Open the log file in append mode
|
||||||
with open(filename, 'a') as log:
|
with open(filename, 'a') as log:
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
total_points = 0
|
total_points = 0
|
||||||
max_points = 0
|
max_points = 0
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ def benchmark(config_path, collector):
|
|||||||
for question_group in item["questions"]:
|
for question_group in item["questions"]:
|
||||||
question_variants = question_group["question_variants"]
|
question_variants = question_group["question_variants"]
|
||||||
criteria = question_group["criteria"]
|
criteria = question_group["criteria"]
|
||||||
|
|
||||||
for q in question_variants:
|
for q in question_variants:
|
||||||
max_points += len(criteria)
|
max_points += len(criteria)
|
||||||
processed_text = preprocess_text(q)
|
processed_text = preprocess_text(q)
|
||||||
@ -54,7 +54,7 @@ def benchmark(config_path, collector):
|
|||||||
results = collector.get_sorted_by_dist(processed_text, n_results=get_chunk_count(), max_token_count=get_max_token_count())
|
results = collector.get_sorted_by_dist(processed_text, n_results=get_chunk_count(), max_token_count=get_max_token_count())
|
||||||
|
|
||||||
points = 0
|
points = 0
|
||||||
|
|
||||||
for c in criteria:
|
for c in criteria:
|
||||||
for p in results:
|
for p in results:
|
||||||
if c in p:
|
if c in p:
|
||||||
@ -69,4 +69,4 @@ def benchmark(config_path, collector):
|
|||||||
|
|
||||||
print(f'##Total points:\n\n{total_points}/{max_points}', file=log)
|
print(f'##Total points:\n\n{total_points}/{max_points}', file=log)
|
||||||
|
|
||||||
return total_points, max_points
|
return total_points, max_points
|
||||||
|
@ -4,16 +4,17 @@ This module is responsible for modifying the chat prompt and history.
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
from extensions.superboogav2.utils import (
|
||||||
|
create_context_text,
|
||||||
|
create_metadata_source
|
||||||
|
)
|
||||||
from modules import chat, shared
|
from modules import chat, shared
|
||||||
from modules.text_generation import get_encoded_length
|
|
||||||
from modules.logging_colors import logger
|
|
||||||
from modules.chat import load_character_memoized
|
from modules.chat import load_character_memoized
|
||||||
from extensions.superboogav2.utils import create_context_text, create_metadata_source
|
from modules.logging_colors import logger
|
||||||
|
from modules.text_generation import get_encoded_length
|
||||||
|
|
||||||
from .data_processor import process_and_add_to_collector
|
|
||||||
from .chromadb import ChromaCollector
|
from .chromadb import ChromaCollector
|
||||||
|
from .data_processor import process_and_add_to_collector
|
||||||
|
|
||||||
CHAT_METADATA = create_metadata_source('automatic-chat-insert')
|
CHAT_METADATA = create_metadata_source('automatic-chat-insert')
|
||||||
|
|
||||||
@ -21,17 +22,17 @@ CHAT_METADATA = create_metadata_source('automatic-chat-insert')
|
|||||||
def _remove_tag_if_necessary(user_input: str):
|
def _remove_tag_if_necessary(user_input: str):
|
||||||
if not parameters.get_is_manual():
|
if not parameters.get_is_manual():
|
||||||
return user_input
|
return user_input
|
||||||
|
|
||||||
return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input)
|
return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input)
|
||||||
|
|
||||||
|
|
||||||
def _should_query(input: str):
|
def _should_query(input: str):
|
||||||
if not parameters.get_is_manual():
|
if not parameters.get_is_manual():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if re.search(r'^\s*!c|!c\s*$', input, re.MULTILINE):
|
if re.search(r'^\s*!c|!c\s*$', input, re.MULTILINE):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -69,7 +70,7 @@ def _concatinate_history(history: dict, state: dict):
|
|||||||
if len(exchange) >= 2:
|
if len(exchange) >= 2:
|
||||||
full_history_text += _format_single_exchange(bot_name, exchange[1])
|
full_history_text += _format_single_exchange(bot_name, exchange[1])
|
||||||
|
|
||||||
return full_history_text[:-1] # Remove the last new line.
|
return full_history_text[:-1] # Remove the last new line.
|
||||||
|
|
||||||
|
|
||||||
def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
|
def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
|
||||||
@ -82,20 +83,20 @@ def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
|
|||||||
for i, messages in enumerate(reversed(history['internal'])):
|
for i, messages in enumerate(reversed(history['internal'])):
|
||||||
for j, message in enumerate(reversed(messages)):
|
for j, message in enumerate(reversed(messages)):
|
||||||
num_message_tokens = get_encoded_length(_format_single_exchange(names[j], message))
|
num_message_tokens = get_encoded_length(_format_single_exchange(names[j], message))
|
||||||
|
|
||||||
# TODO: This is an extremely naive solution. A more robust implementation must be made.
|
# TODO: This is an extremely naive solution. A more robust implementation must be made.
|
||||||
if history_tokens + num_context_tokens <= max_len:
|
if history_tokens + num_context_tokens <= max_len:
|
||||||
# This message can be replaced
|
# This message can be replaced
|
||||||
replace_position = (i, j)
|
replace_position = (i, j)
|
||||||
|
|
||||||
history_tokens += num_message_tokens
|
history_tokens += num_message_tokens
|
||||||
|
|
||||||
if replace_position is None:
|
if replace_position is None:
|
||||||
logger.warn("The provided context_text is too long to replace any message in the history.")
|
logger.warn("The provided context_text is too long to replace any message in the history.")
|
||||||
else:
|
else:
|
||||||
# replace the message at replace_position with context_text
|
# replace the message at replace_position with context_text
|
||||||
i, j = replace_position
|
i, j = replace_position
|
||||||
history['internal'][-i-1][-j-1] = context_text
|
history['internal'][-i - 1][-j - 1] = context_text
|
||||||
|
|
||||||
|
|
||||||
def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):
|
def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):
|
||||||
@ -120,5 +121,5 @@ def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector
|
|||||||
user_input = create_context_text(results) + user_input
|
user_input = create_context_text(results) + user_input
|
||||||
elif parameters.get_injection_strategy() == parameters.HIJACK_LAST_IN_CONTEXT:
|
elif parameters.get_injection_strategy() == parameters.HIJACK_LAST_IN_CONTEXT:
|
||||||
_hijack_last(create_context_text(results), kwargs['history'], state['truncation_length'], state)
|
_hijack_last(create_context_text(results), kwargs['history'], state['truncation_length'], state)
|
||||||
|
|
||||||
return chat.generate_chat_prompt(user_input, state, **kwargs)
|
return chat.generate_chat_prompt(user_input, state, **kwargs)
|
||||||
|
@ -1,42 +1,23 @@
|
|||||||
import threading
|
|
||||||
import chromadb
|
|
||||||
import posthog
|
|
||||||
import torch
|
|
||||||
import math
|
import math
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import chromadb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import extensions.superboogav2.parameters as parameters
|
import posthog
|
||||||
|
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from sentence_transformers import SentenceTransformer
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
|
import extensions.superboogav2.parameters as parameters
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.text_generation import encode, decode
|
from modules.text_generation import decode, encode
|
||||||
|
|
||||||
logger.debug('Intercepting all calls to posthog.')
|
# Intercept calls to posthog
|
||||||
posthog.capture = lambda *args, **kwargs: None
|
posthog.capture = lambda *args, **kwargs: None
|
||||||
|
|
||||||
|
|
||||||
class Collecter():
|
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Embedder():
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def embed(self, text: str) -> list[torch.Tensor]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Info:
|
class Info:
|
||||||
def __init__(self, start_index, text_with_context, distance, id):
|
def __init__(self, start_index, text_with_context, distance, id):
|
||||||
@ -58,7 +39,7 @@ class Info:
|
|||||||
elif parameters.get_new_dist_strategy() == parameters.DIST_ARITHMETIC_STRATEGY:
|
elif parameters.get_new_dist_strategy() == parameters.DIST_ARITHMETIC_STRATEGY:
|
||||||
# Arithmetic mean
|
# Arithmetic mean
|
||||||
return (self.distance + other_info.distance) / 2
|
return (self.distance + other_info.distance) / 2
|
||||||
else: # Min is default
|
else: # Min is default
|
||||||
return min(self.distance, other_info.distance)
|
return min(self.distance, other_info.distance)
|
||||||
|
|
||||||
def merge_with(self, other_info):
|
def merge_with(self, other_info):
|
||||||
@ -66,7 +47,7 @@ class Info:
|
|||||||
s2 = other_info.text_with_context
|
s2 = other_info.text_with_context
|
||||||
s1_start = self.start_index
|
s1_start = self.start_index
|
||||||
s2_start = other_info.start_index
|
s2_start = other_info.start_index
|
||||||
|
|
||||||
new_dist = self.calculate_distance(other_info)
|
new_dist = self.calculate_distance(other_info)
|
||||||
|
|
||||||
if self.should_merge(s1, s2, s1_start, s2_start):
|
if self.should_merge(s1, s2, s1_start, s2_start):
|
||||||
@ -84,55 +65,58 @@ class Info:
|
|||||||
return Info(s2_start, s2 + s1[overlap:], new_dist, other_info.id)
|
return Info(s2_start, s2 + s1[overlap:], new_dist, other_info.id)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def should_merge(s1, s2, s1_start, s2_start):
|
def should_merge(s1, s2, s1_start, s2_start):
|
||||||
# Check if s1 and s2 are adjacent or overlapping
|
# Check if s1 and s2 are adjacent or overlapping
|
||||||
s1_end = s1_start + len(s1)
|
s1_end = s1_start + len(s1)
|
||||||
s2_end = s2_start + len(s2)
|
s2_end = s2_start + len(s2)
|
||||||
|
|
||||||
return not (s1_end < s2_start or s2_end < s1_start)
|
return not (s1_end < s2_start or s2_end < s1_start)
|
||||||
|
|
||||||
class ChromaCollector(Collecter):
|
|
||||||
def __init__(self, embedder: Embedder):
|
class ChromaCollector():
|
||||||
super().__init__()
|
def __init__(self):
|
||||||
|
name = ''.join(random.choice('ab') for _ in range(10))
|
||||||
|
|
||||||
|
self.name = name
|
||||||
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
||||||
self.embedder = embedder
|
self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
|
||||||
self.collection = self.chroma_client.create_collection(name="context", embedding_function=self.embedder.embed)
|
|
||||||
self.ids = []
|
self.ids = []
|
||||||
self.id_to_info = {}
|
self.id_to_info = {}
|
||||||
self.embeddings_cache = {}
|
self.embeddings_cache = {}
|
||||||
self.lock = threading.Lock() # Locking so the server doesn't break.
|
self.lock = threading.Lock() # Locking so the server doesn't break.
|
||||||
|
|
||||||
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None):
|
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
assert metadatas is None or len(metadatas) == len(texts), "metadatas must be None or have the same length as texts"
|
assert metadatas is None or len(metadatas) == len(texts), "metadatas must be None or have the same length as texts"
|
||||||
|
|
||||||
if len(texts) == 0:
|
if len(texts) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
new_ids = self._get_new_ids(len(texts))
|
new_ids = self._get_new_ids(len(texts))
|
||||||
|
|
||||||
(existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
(existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
||||||
(non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas)
|
(non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas)
|
||||||
|
|
||||||
# If there are any already existing texts, add them all at once.
|
# If there are any already existing texts, add them all at once.
|
||||||
if existing_texts:
|
if existing_texts:
|
||||||
logger.info(f'Adding {len(existing_embeddings)} cached embeddings.')
|
logger.info(f'Adding {len(existing_embeddings)} cached embeddings.')
|
||||||
args = {'embeddings': existing_embeddings, 'documents': existing_texts, 'ids': existing_ids}
|
args = {'embeddings': existing_embeddings, 'documents': existing_texts, 'ids': existing_ids}
|
||||||
if metadatas is not None:
|
if metadatas is not None:
|
||||||
args['metadatas'] = existing_metas
|
args['metadatas'] = existing_metas
|
||||||
self.collection.add(**args)
|
self.collection.add(**args)
|
||||||
|
|
||||||
# If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
|
# If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
|
||||||
if non_existing_texts:
|
if non_existing_texts:
|
||||||
non_existing_embeddings = self.embedder.embed(non_existing_texts).tolist()
|
non_existing_embeddings = embedder(non_existing_texts)
|
||||||
for text, embedding in zip(non_existing_texts, non_existing_embeddings):
|
for text, embedding in zip(non_existing_texts, non_existing_embeddings):
|
||||||
self.embeddings_cache[text] = embedding
|
self.embeddings_cache[text] = embedding
|
||||||
|
|
||||||
logger.info(f'Adding {len(non_existing_embeddings)} new embeddings.')
|
logger.info(f'Adding {len(non_existing_embeddings)} new embeddings.')
|
||||||
args = {'embeddings': non_existing_embeddings, 'documents': non_existing_texts, 'ids': non_existing_ids}
|
args = {'embeddings': non_existing_embeddings, 'documents': non_existing_texts, 'ids': non_existing_ids}
|
||||||
if metadatas is not None:
|
if metadatas is not None:
|
||||||
args['metadatas'] = non_existing_metas
|
args['metadatas'] = non_existing_metas
|
||||||
self.collection.add(**args)
|
self.collection.add(**args)
|
||||||
|
|
||||||
@ -145,7 +129,6 @@ class ChromaCollector(Collecter):
|
|||||||
self.id_to_info.update(new_info)
|
self.id_to_info.update(new_info)
|
||||||
self.ids.extend(new_ids)
|
self.ids.extend(new_ids)
|
||||||
|
|
||||||
|
|
||||||
def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]):
|
def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]):
|
||||||
existing_texts, non_existing_texts = [], []
|
existing_texts, non_existing_texts = [], []
|
||||||
existing_embeddings = []
|
existing_embeddings = []
|
||||||
@ -169,7 +152,6 @@ class ChromaCollector(Collecter):
|
|||||||
return (existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
return (existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
||||||
(non_existing_texts, non_existing_ids, non_existing_metas)
|
(non_existing_texts, non_existing_ids, non_existing_metas)
|
||||||
|
|
||||||
|
|
||||||
def _get_new_ids(self, num_new_ids: int):
|
def _get_new_ids(self, num_new_ids: int):
|
||||||
if self.ids:
|
if self.ids:
|
||||||
max_existing_id = max(int(id_) for id_ in self.ids)
|
max_existing_id = max(int(id_) for id_ in self.ids)
|
||||||
@ -178,7 +160,6 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
return [str(i + max_existing_id + 1) for i in range(num_new_ids)]
|
return [str(i + max_existing_id + 1) for i in range(num_new_ids)]
|
||||||
|
|
||||||
|
|
||||||
def _find_min_max_start_index(self):
|
def _find_min_max_start_index(self):
|
||||||
max_index, min_index = 0, float('inf')
|
max_index, min_index = 0, float('inf')
|
||||||
for _, val in self.id_to_info.items():
|
for _, val in self.id_to_info.items():
|
||||||
@ -188,34 +169,34 @@ class ChromaCollector(Collecter):
|
|||||||
min_index = val['start_index']
|
min_index = val['start_index']
|
||||||
return min_index, max_index
|
return min_index, max_index
|
||||||
|
|
||||||
|
# NB: Does not make sense to weigh excerpts from different documents.
|
||||||
# NB: Does not make sense to weigh excerpts from different documents.
|
|
||||||
# But let's say that's the user's problem. Perfect world scenario:
|
# But let's say that's the user's problem. Perfect world scenario:
|
||||||
# Apply time weighing to different documents. For each document, then, add
|
# Apply time weighing to different documents. For each document, then, add
|
||||||
# separate time weighing.
|
# separate time weighing.
|
||||||
|
|
||||||
def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float):
|
def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float):
|
||||||
sigmoid = lambda x: 1 / (1 + np.exp(-x))
|
def sigmoid(x):
|
||||||
|
return 1 / (1 + np.exp(-x))
|
||||||
|
|
||||||
weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))
|
weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))
|
||||||
|
|
||||||
# Scale to [0,time_power] and shift it up to [1-time_power, 1]
|
# Scale to [0,time_power] and shift it up to [1-time_power, 1]
|
||||||
weights = weights - min(weights)
|
weights = weights - min(weights)
|
||||||
weights = weights * (time_power / max(weights))
|
weights = weights * (time_power / max(weights))
|
||||||
weights = weights + (1 - time_power)
|
weights = weights + (1 - time_power)
|
||||||
|
|
||||||
# Reverse the weights
|
# Reverse the weights
|
||||||
weights = weights[::-1]
|
weights = weights[::-1]
|
||||||
|
|
||||||
for info in infos:
|
for info in infos:
|
||||||
index = info.start_index
|
index = info.start_index
|
||||||
info.distance *= weights[index]
|
info.distance *= weights[index]
|
||||||
|
|
||||||
|
|
||||||
def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float):
|
def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float):
|
||||||
# Ensure there are infos to filter
|
# Ensure there are infos to filter
|
||||||
if not infos:
|
if not infos:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Find info with minimum distance
|
# Find info with minimum distance
|
||||||
min_info = min(infos, key=lambda x: x.distance)
|
min_info = min(infos, key=lambda x: x.distance)
|
||||||
|
|
||||||
@ -231,7 +212,6 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
return filtered_infos
|
return filtered_infos
|
||||||
|
|
||||||
|
|
||||||
def _merge_infos(self, infos: list[Info]):
|
def _merge_infos(self, infos: list[Info]):
|
||||||
merged_infos = []
|
merged_infos = []
|
||||||
current_info = infos[0]
|
current_info = infos[0]
|
||||||
@ -247,8 +227,8 @@ class ChromaCollector(Collecter):
|
|||||||
merged_infos.append(current_info)
|
merged_infos.append(current_info)
|
||||||
return merged_infos
|
return merged_infos
|
||||||
|
|
||||||
|
|
||||||
# Main function for retrieving chunks by distance. It performs merging, time weighing, and mean filtering.
|
# Main function for retrieving chunks by distance. It performs merging, time weighing, and mean filtering.
|
||||||
|
|
||||||
def _get_documents_ids_distances(self, search_strings: list[str], n_results: int):
|
def _get_documents_ids_distances(self, search_strings: list[str], n_results: int):
|
||||||
n_results = min(len(self.ids), n_results)
|
n_results = min(len(self.ids), n_results)
|
||||||
if n_results == 0:
|
if n_results == 0:
|
||||||
@ -262,11 +242,11 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
for search_string in search_strings:
|
for search_string in search_strings:
|
||||||
result = self.collection.query(query_texts=search_string, n_results=math.ceil(n_results / len(search_strings)), include=['distances'])
|
result = self.collection.query(query_texts=search_string, n_results=math.ceil(n_results / len(search_strings)), include=['distances'])
|
||||||
curr_infos = [Info(start_index=self.id_to_info[id]['start_index'],
|
curr_infos = [Info(start_index=self.id_to_info[id]['start_index'],
|
||||||
text_with_context=self.id_to_info[id]['text_with_context'],
|
text_with_context=self.id_to_info[id]['text_with_context'],
|
||||||
distance=distance, id=id)
|
distance=distance, id=id)
|
||||||
for id, distance in zip(result['ids'][0], result['distances'][0])]
|
for id, distance in zip(result['ids'][0], result['distances'][0])]
|
||||||
|
|
||||||
self._apply_sigmoid_time_weighing(infos=curr_infos, document_len=max_start_index - min_start_index + 1, time_steepness=parameters.get_time_steepness(), time_power=parameters.get_time_power())
|
self._apply_sigmoid_time_weighing(infos=curr_infos, document_len=max_start_index - min_start_index + 1, time_steepness=parameters.get_time_steepness(), time_power=parameters.get_time_power())
|
||||||
curr_infos = self._filter_outliers_by_median_distance(curr_infos, parameters.get_significant_level())
|
curr_infos = self._filter_outliers_by_median_distance(curr_infos, parameters.get_significant_level())
|
||||||
infos.extend(curr_infos)
|
infos.extend(curr_infos)
|
||||||
@ -279,23 +259,23 @@ class ChromaCollector(Collecter):
|
|||||||
distances = [inf.distance for inf in infos]
|
distances = [inf.distance for inf in infos]
|
||||||
|
|
||||||
return texts_with_context, ids, distances
|
return texts_with_context, ids, distances
|
||||||
|
|
||||||
|
|
||||||
# Get chunks by similarity
|
# Get chunks by similarity
|
||||||
|
|
||||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
documents, _, _ = self._get_documents_ids_distances(search_strings, n_results)
|
documents, _, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
|
||||||
# Get ids by similarity
|
# Get ids by similarity
|
||||||
|
|
||||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
_, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
_, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
|
||||||
# Cutoff token count
|
# Cutoff token count
|
||||||
|
|
||||||
def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int):
|
def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int):
|
||||||
# TODO: Move to caller; We add delimiters there which might go over the limit.
|
# TODO: Move to caller; We add delimiters there which might go over the limit.
|
||||||
current_token_count = 0
|
current_token_count = 0
|
||||||
@ -308,7 +288,7 @@ class ChromaCollector(Collecter):
|
|||||||
# If adding this document would exceed the max token count,
|
# If adding this document would exceed the max token count,
|
||||||
# truncate the document to fit within the limit.
|
# truncate the document to fit within the limit.
|
||||||
remaining_tokens = max_token_count - current_token_count
|
remaining_tokens = max_token_count - current_token_count
|
||||||
|
|
||||||
truncated_doc = decode(doc_tokens[:remaining_tokens], skip_special_tokens=True)
|
truncated_doc = decode(doc_tokens[:remaining_tokens], skip_special_tokens=True)
|
||||||
return_documents.append(truncated_doc)
|
return_documents.append(truncated_doc)
|
||||||
break
|
break
|
||||||
@ -317,29 +297,28 @@ class ChromaCollector(Collecter):
|
|||||||
current_token_count += doc_token_count
|
current_token_count += doc_token_count
|
||||||
|
|
||||||
return return_documents
|
return return_documents
|
||||||
|
|
||||||
|
|
||||||
# Get chunks by similarity and then sort by ids
|
# Get chunks by similarity and then sort by ids
|
||||||
|
|
||||||
def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||||
sorted_docs = [x for _, x in sorted(zip(ids, documents))]
|
sorted_docs = [x for _, x in sorted(zip(ids, documents))]
|
||||||
|
|
||||||
return self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
return self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
||||||
|
|
||||||
|
|
||||||
# Get chunks by similarity and then sort by distance (lowest distance is last).
|
# Get chunks by similarity and then sort by distance (lowest distance is last).
|
||||||
|
|
||||||
def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
documents, _, distances = self._get_documents_ids_distances(search_strings, n_results)
|
documents, _, distances = self._get_documents_ids_distances(search_strings, n_results)
|
||||||
sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])] # sorted lowest -> highest
|
sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])] # sorted lowest -> highest
|
||||||
|
|
||||||
# If a document is truncated or completely skipped, it would be with high distance.
|
# If a document is truncated or completely skipped, it would be with high distance.
|
||||||
return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
||||||
return_documents.reverse() # highest -> lowest
|
return_documents.reverse() # highest -> lowest
|
||||||
|
|
||||||
return return_documents
|
return return_documents
|
||||||
|
|
||||||
|
|
||||||
def delete(self, ids_to_delete: list[str], where: dict):
|
def delete(self, ids_to_delete: list[str], where: dict):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
@ -354,23 +333,16 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.')
|
logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.')
|
||||||
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.chroma_client.reset()
|
self.chroma_client.reset()
|
||||||
self.collection = self.chroma_client.create_collection("context", embedding_function=self.embedder.embed)
|
|
||||||
self.ids = []
|
self.ids = []
|
||||||
self.id_to_info = {}
|
self.chroma_client.delete_collection(name=self.name)
|
||||||
|
self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
|
||||||
|
|
||||||
logger.info('Successfully cleared all records and reset chromaDB.')
|
logger.info('Successfully cleared all records and reset chromaDB.')
|
||||||
|
|
||||||
|
|
||||||
class SentenceTransformerEmbedder(Embedder):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
logger.debug('Creating Sentence Embedder...')
|
|
||||||
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
|
|
||||||
self.embed = self.model.encode
|
|
||||||
|
|
||||||
|
|
||||||
def make_collector():
|
def make_collector():
|
||||||
return ChromaCollector(SentenceTransformerEmbedder())
|
return ChromaCollector()
|
||||||
|
@ -11,32 +11,29 @@ This module contains utils for preprocessing the text before converting it to em
|
|||||||
* removing specific parts of speech (adverbs and interjections)
|
* removing specific parts of speech (adverbs and interjections)
|
||||||
- TextSummarizer extracts the most important sentences from a long string using text-ranking.
|
- TextSummarizer extracts the most important sentences from a long string using text-ranking.
|
||||||
"""
|
"""
|
||||||
import pytextrank
|
|
||||||
import string
|
|
||||||
import spacy
|
|
||||||
import math
|
import math
|
||||||
import nltk
|
|
||||||
import re
|
import re
|
||||||
|
import string
|
||||||
|
|
||||||
|
import nltk
|
||||||
|
import spacy
|
||||||
from nltk.corpus import stopwords
|
from nltk.corpus import stopwords
|
||||||
from nltk.stem import WordNetLemmatizer
|
from nltk.stem import WordNetLemmatizer
|
||||||
from num2words import num2words
|
from num2words import num2words
|
||||||
|
|
||||||
|
|
||||||
class TextPreprocessorBuilder:
|
class TextPreprocessorBuilder:
|
||||||
# Define class variables as None initially
|
# Define class variables as None initially
|
||||||
_stop_words = set(stopwords.words('english'))
|
_stop_words = set(stopwords.words('english'))
|
||||||
_lemmatizer = WordNetLemmatizer()
|
_lemmatizer = WordNetLemmatizer()
|
||||||
|
|
||||||
# Some of the functions are expensive. We cache the results.
|
# Some of the functions are expensive. We cache the results.
|
||||||
_lemmatizer_cache = {}
|
_lemmatizer_cache = {}
|
||||||
_pos_remove_cache = {}
|
_pos_remove_cache = {}
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, text: str):
|
def __init__(self, text: str):
|
||||||
self.text = text
|
self.text = text
|
||||||
|
|
||||||
|
|
||||||
def to_lower(self):
|
def to_lower(self):
|
||||||
# Match both words and non-word characters
|
# Match both words and non-word characters
|
||||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||||
@ -49,7 +46,6 @@ class TextPreprocessorBuilder:
|
|||||||
self.text = "".join(tokens)
|
self.text = "".join(tokens)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def num_to_word(self, min_len: int = 1):
|
def num_to_word(self, min_len: int = 1):
|
||||||
# Match both words and non-word characters
|
# Match both words and non-word characters
|
||||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||||
@ -58,11 +54,10 @@ class TextPreprocessorBuilder:
|
|||||||
if token.isdigit() and len(token) >= min_len:
|
if token.isdigit() and len(token) >= min_len:
|
||||||
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
||||||
# 740700 will become "seven hundred and forty thousand seven hundred".
|
# 740700 will become "seven hundred and forty thousand seven hundred".
|
||||||
tokens[i] = num2words(int(token)).replace(",","") # Remove commas from num2words.
|
tokens[i] = num2words(int(token)).replace(",", "") # Remove commas from num2words.
|
||||||
self.text = "".join(tokens)
|
self.text = "".join(tokens)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def num_to_char_long(self, min_len: int = 1):
|
def num_to_char_long(self, min_len: int = 1):
|
||||||
# Match both words and non-word characters
|
# Match both words and non-word characters
|
||||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||||
@ -71,11 +66,13 @@ class TextPreprocessorBuilder:
|
|||||||
if token.isdigit() and len(token) >= min_len:
|
if token.isdigit() and len(token) >= min_len:
|
||||||
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
||||||
# 740700 will become HHHHHHEEEEEAAAAHHHAAA
|
# 740700 will become HHHHHHEEEEEAAAAHHHAAA
|
||||||
convert_token = lambda token: ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]
|
def convert_token(token):
|
||||||
|
return ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]
|
||||||
|
|
||||||
tokens[i] = convert_token(tokens[i])
|
tokens[i] = convert_token(tokens[i])
|
||||||
self.text = "".join(tokens)
|
self.text = "".join(tokens)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def num_to_char(self, min_len: int = 1):
|
def num_to_char(self, min_len: int = 1):
|
||||||
# Match both words and non-word characters
|
# Match both words and non-word characters
|
||||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||||
@ -87,15 +84,15 @@ class TextPreprocessorBuilder:
|
|||||||
tokens[i] = ''.join(chr(int(digit) + 65) for digit in token)
|
tokens[i] = ''.join(chr(int(digit) + 65) for digit in token)
|
||||||
self.text = "".join(tokens)
|
self.text = "".join(tokens)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def merge_spaces(self):
|
def merge_spaces(self):
|
||||||
self.text = re.sub(' +', ' ', self.text)
|
self.text = re.sub(' +', ' ', self.text)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def strip(self):
|
def strip(self):
|
||||||
self.text = self.text.strip()
|
self.text = self.text.strip()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def remove_punctuation(self):
|
def remove_punctuation(self):
|
||||||
self.text = self.text.translate(str.maketrans('', '', string.punctuation))
|
self.text = self.text.translate(str.maketrans('', '', string.punctuation))
|
||||||
return self
|
return self
|
||||||
@ -103,7 +100,7 @@ class TextPreprocessorBuilder:
|
|||||||
def remove_stopwords(self):
|
def remove_stopwords(self):
|
||||||
self.text = "".join([word for word in re.findall(r'\b\w+\b|\W+', self.text) if word not in TextPreprocessorBuilder._stop_words])
|
self.text = "".join([word for word in re.findall(r'\b\w+\b|\W+', self.text) if word not in TextPreprocessorBuilder._stop_words])
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def remove_specific_pos(self):
|
def remove_specific_pos(self):
|
||||||
"""
|
"""
|
||||||
In the English language, adverbs and interjections rarely provide meaningful information.
|
In the English language, adverbs and interjections rarely provide meaningful information.
|
||||||
@ -140,7 +137,7 @@ class TextPreprocessorBuilder:
|
|||||||
if processed_text:
|
if processed_text:
|
||||||
self.text = processed_text
|
self.text = processed_text
|
||||||
return self
|
return self
|
||||||
|
|
||||||
new_text = "".join([TextPreprocessorBuilder._lemmatizer.lemmatize(word) for word in re.findall(r'\b\w+\b|\W+', self.text)])
|
new_text = "".join([TextPreprocessorBuilder._lemmatizer.lemmatize(word) for word in re.findall(r'\b\w+\b|\W+', self.text)])
|
||||||
TextPreprocessorBuilder._lemmatizer_cache[self.text] = new_text
|
TextPreprocessorBuilder._lemmatizer_cache[self.text] = new_text
|
||||||
self.text = new_text
|
self.text = new_text
|
||||||
@ -150,6 +147,7 @@ class TextPreprocessorBuilder:
|
|||||||
def build(self):
|
def build(self):
|
||||||
return self.text
|
return self.text
|
||||||
|
|
||||||
|
|
||||||
class TextSummarizer:
|
class TextSummarizer:
|
||||||
_nlp_pipeline = None
|
_nlp_pipeline = None
|
||||||
_cache = {}
|
_cache = {}
|
||||||
@ -165,7 +163,7 @@ class TextSummarizer:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def process_long_text(text: str, min_num_sent: int) -> list[str]:
|
def process_long_text(text: str, min_num_sent: int) -> list[str]:
|
||||||
"""
|
"""
|
||||||
This function applies a text summarization process on a given text string, extracting
|
This function applies a text summarization process on a given text string, extracting
|
||||||
the most important sentences based on the principle that 20% of the content is responsible
|
the most important sentences based on the principle that 20% of the content is responsible
|
||||||
for 80% of the meaning (the Pareto Principle).
|
for 80% of the meaning (the Pareto Principle).
|
||||||
|
|
||||||
@ -193,7 +191,7 @@ class TextSummarizer:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
result = [text]
|
result = [text]
|
||||||
|
|
||||||
# Store the result in cache before returning it
|
# Store the result in cache before returning it
|
||||||
TextSummarizer._cache[cache_key] = result
|
TextSummarizer._cache[cache_key] = result
|
||||||
return result
|
return result
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
"""
|
"""
|
||||||
This module is responsible for processing the corpus and feeding it into chromaDB. It will receive a corpus of text.
|
This module is responsible for processing the corpus and feeding it into chromaDB. It will receive a corpus of text.
|
||||||
It will then split it into chunks of specified length. For each of those chunks, it will append surrounding context.
|
It will then split it into chunks of specified length. For each of those chunks, it will append surrounding context.
|
||||||
It will only include full words.
|
It will only include full words.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
|
||||||
import bisect
|
import bisect
|
||||||
|
import re
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
|
||||||
from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer
|
|
||||||
from .chromadb import ChromaCollector
|
from .chromadb import ChromaCollector
|
||||||
|
from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer
|
||||||
|
|
||||||
|
|
||||||
def preprocess_text_no_summary(text) -> str:
|
def preprocess_text_no_summary(text) -> str:
|
||||||
builder = TextPreprocessorBuilder(text)
|
builder = TextPreprocessorBuilder(text)
|
||||||
@ -42,7 +43,7 @@ def preprocess_text_no_summary(text) -> str:
|
|||||||
builder.num_to_char(parameters.get_min_num_length())
|
builder.num_to_char(parameters.get_min_num_length())
|
||||||
elif parameters.get_num_conversion_strategy() == parameters.NUM_TO_CHAR_LONG_METHOD:
|
elif parameters.get_num_conversion_strategy() == parameters.NUM_TO_CHAR_LONG_METHOD:
|
||||||
builder.num_to_char_long(parameters.get_min_num_length())
|
builder.num_to_char_long(parameters.get_min_num_length())
|
||||||
|
|
||||||
return builder.build()
|
return builder.build()
|
||||||
|
|
||||||
|
|
||||||
@ -53,10 +54,10 @@ def preprocess_text(text) -> list[str]:
|
|||||||
|
|
||||||
def _create_chunks_with_context(corpus, chunk_len, context_left, context_right):
|
def _create_chunks_with_context(corpus, chunk_len, context_left, context_right):
|
||||||
"""
|
"""
|
||||||
This function takes a corpus of text and splits it into chunks of a specified length,
|
This function takes a corpus of text and splits it into chunks of a specified length,
|
||||||
then adds a specified amount of context to each chunk. The context is added by first
|
then adds a specified amount of context to each chunk. The context is added by first
|
||||||
going backwards from the start of the chunk and then going forwards from the end of the
|
going backwards from the start of the chunk and then going forwards from the end of the
|
||||||
chunk, ensuring that the context includes only whole words and that the total context length
|
chunk, ensuring that the context includes only whole words and that the total context length
|
||||||
does not exceed the specified limit. This function uses binary search for efficiency.
|
does not exceed the specified limit. This function uses binary search for efficiency.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -102,7 +103,7 @@ def _create_chunks_with_context(corpus, chunk_len, context_left, context_right):
|
|||||||
# Combine all the words in the context range (before, chunk, and after)
|
# Combine all the words in the context range (before, chunk, and after)
|
||||||
chunk_with_context = ''.join(words[context_start_index:context_end_index])
|
chunk_with_context = ''.join(words[context_start_index:context_end_index])
|
||||||
chunks_with_context.append(chunk_with_context)
|
chunks_with_context.append(chunk_with_context)
|
||||||
|
|
||||||
# Determine the start index of the chunk with context
|
# Determine the start index of the chunk with context
|
||||||
chunk_with_context_start_index = word_start_indices[context_start_index]
|
chunk_with_context_start_index = word_start_indices[context_start_index]
|
||||||
chunk_with_context_start_indices.append(chunk_with_context_start_index)
|
chunk_with_context_start_indices.append(chunk_with_context_start_index)
|
||||||
@ -125,9 +126,9 @@ def _clear_chunks(data_chunks, data_chunks_with_context, data_chunk_starting_ind
|
|||||||
seen_chunk_start = seen_chunks.get(chunk)
|
seen_chunk_start = seen_chunks.get(chunk)
|
||||||
if seen_chunk_start:
|
if seen_chunk_start:
|
||||||
# If we've already seen this exact chunk, and the context around it is very close to the seen chunk, then skip it.
|
# If we've already seen this exact chunk, and the context around it is very close to the seen chunk, then skip it.
|
||||||
if abs(seen_chunk_start-index) < parameters.get_delta_start():
|
if abs(seen_chunk_start - index) < parameters.get_delta_start():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
distinct_data_chunks.append(chunk)
|
distinct_data_chunks.append(chunk)
|
||||||
distinct_data_chunks_with_context.append(context)
|
distinct_data_chunks_with_context.append(context)
|
||||||
distinct_data_chunk_starting_indices.append(index)
|
distinct_data_chunk_starting_indices.append(index)
|
||||||
@ -206,4 +207,4 @@ def process_and_add_to_collector(corpus: str, collector: ChromaCollector, clear_
|
|||||||
|
|
||||||
if clear_collector_before_adding:
|
if clear_collector_before_adding:
|
||||||
collector.clear()
|
collector.clear()
|
||||||
collector.add(data_chunks, data_chunks_with_context, data_chunk_starting_indices, [metadata]*len(data_chunks) if metadata is not None else None)
|
collector.add(data_chunks, data_chunks_with_context, data_chunk_starting_indices, [metadata] * len(data_chunks) if metadata is not None else None)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import requests
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import requests
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
@ -9,6 +9,7 @@ import extensions.superboogav2.parameters as parameters
|
|||||||
from .data_processor import process_and_add_to_collector
|
from .data_processor import process_and_add_to_collector
|
||||||
from .utils import create_metadata_source
|
from .utils import create_metadata_source
|
||||||
|
|
||||||
|
|
||||||
def _download_single(url):
|
def _download_single(url):
|
||||||
response = requests.get(url, timeout=5)
|
response = requests.get(url, timeout=5)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
@ -62,4 +63,4 @@ def feed_url_into_collector(urls, collector):
|
|||||||
text = '\n'.join([s.strip() for s in strings])
|
text = '\n'.join([s.strip() for s in strings])
|
||||||
all_text += text
|
all_text += text
|
||||||
|
|
||||||
process_and_add_to_collector(all_text, collector, False, create_metadata_source('url-download'))
|
process_and_add_to_collector(all_text, collector, False, create_metadata_source('url-download'))
|
||||||
|
@ -4,13 +4,12 @@ This module is responsible for handling and modifying the notebook text.
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
|
||||||
from modules import shared
|
|
||||||
from modules.logging_colors import logger
|
|
||||||
from extensions.superboogav2.utils import create_context_text
|
from extensions.superboogav2.utils import create_context_text
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
from .data_processor import preprocess_text
|
from .data_processor import preprocess_text
|
||||||
|
|
||||||
|
|
||||||
def _remove_special_tokens(string):
|
def _remove_special_tokens(string):
|
||||||
pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
|
pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
|
||||||
return re.sub(pattern, '', string)
|
return re.sub(pattern, '', string)
|
||||||
@ -37,4 +36,4 @@ def input_modifier_internal(string, collector, is_chat):
|
|||||||
# Make the injection
|
# Make the injection
|
||||||
string = string.replace('<|injection-point|>', create_context_text(results))
|
string = string.replace('<|injection-point|>', create_context_text(results))
|
||||||
|
|
||||||
return _remove_special_tokens(string)
|
return _remove_special_tokens(string)
|
||||||
|
@ -3,22 +3,24 @@ This module implements a hyperparameter optimization routine for the embedding a
|
|||||||
|
|
||||||
Each run, the optimizer will set the default values inside the hyperparameters. At the end, it will output the best ones it has found.
|
Each run, the optimizer will set the default values inside the hyperparameters. At the end, it will output the best ones it has found.
|
||||||
"""
|
"""
|
||||||
import re
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import optuna
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import logging
|
import optuna
|
||||||
import hashlib
|
|
||||||
logging.getLogger('optuna').setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
logging.getLogger('optuna').setLevel(logging.WARNING)
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
from .benchmark import benchmark
|
from .benchmark import benchmark
|
||||||
from .parameters import Parameters
|
from .parameters import Parameters
|
||||||
from modules.logging_colors import logger
|
|
||||||
|
|
||||||
|
|
||||||
# Format the parameters into markdown format.
|
# Format the parameters into markdown format.
|
||||||
@ -28,7 +30,7 @@ def _markdown_hyperparams():
|
|||||||
# Escape any markdown syntax
|
# Escape any markdown syntax
|
||||||
param_name = re.sub(r"([_*\[\]()~`>#+-.!])", r"\\\1", param_name)
|
param_name = re.sub(r"([_*\[\]()~`>#+-.!])", r"\\\1", param_name)
|
||||||
param_value_default = re.sub(r"([_*\[\]()~`>#+-.!])", r"\\\1", str(param_value['default'])) if param_value['default'] else ' '
|
param_value_default = re.sub(r"([_*\[\]()~`>#+-.!])", r"\\\1", str(param_value['default'])) if param_value['default'] else ' '
|
||||||
|
|
||||||
res.append('* {}: **{}**'.format(param_name, param_value_default))
|
res.append('* {}: **{}**'.format(param_name, param_value_default))
|
||||||
|
|
||||||
return '\n'.join(res)
|
return '\n'.join(res)
|
||||||
@ -49,13 +51,13 @@ def _convert_np_types(params):
|
|||||||
# Set the default values for the hyperparameters.
|
# Set the default values for the hyperparameters.
|
||||||
def _set_hyperparameters(params):
|
def _set_hyperparameters(params):
|
||||||
for param_name, param_value in params.items():
|
for param_name, param_value in params.items():
|
||||||
if param_name in Parameters.getInstance().hyperparameters:
|
if param_name in Parameters.getInstance().hyperparameters:
|
||||||
Parameters.getInstance().hyperparameters[param_name]['default'] = param_value
|
Parameters.getInstance().hyperparameters[param_name]['default'] = param_value
|
||||||
|
|
||||||
|
|
||||||
# Check if the parameter is for optimization.
|
# Check if the parameter is for optimization.
|
||||||
def _is_optimization_param(val):
|
def _is_optimization_param(val):
|
||||||
is_opt = val.get('should_optimize', False) # Either does not exist or is false
|
is_opt = val.get('should_optimize', False) # Either does not exist or is false
|
||||||
return is_opt
|
return is_opt
|
||||||
|
|
||||||
|
|
||||||
@ -67,7 +69,7 @@ def _get_params_hash(params):
|
|||||||
|
|
||||||
def optimize(collector, progress=gr.Progress()):
|
def optimize(collector, progress=gr.Progress()):
|
||||||
# Inform the user that something is happening.
|
# Inform the user that something is happening.
|
||||||
progress(0, desc=f'Setting Up...')
|
progress(0, desc='Setting Up...')
|
||||||
|
|
||||||
# Track the current step
|
# Track the current step
|
||||||
current_step = 0
|
current_step = 0
|
||||||
@ -132,4 +134,4 @@ def optimize(collector, progress=gr.Progress()):
|
|||||||
with open('best_params.json', 'w') as fp:
|
with open('best_params.json', 'w') as fp:
|
||||||
json.dump(_convert_np_types(best_params), fp, indent=4)
|
json.dump(_convert_np_types(best_params), fp, indent=4)
|
||||||
|
|
||||||
return str_result
|
return str_result
|
||||||
|
@ -1,18 +1,16 @@
|
|||||||
"""
|
"""
|
||||||
This module provides a singleton class `Parameters` that is used to manage all hyperparameters for the embedding application.
|
This module provides a singleton class `Parameters` that is used to manage all hyperparameters for the embedding application.
|
||||||
It expects a JSON file in `extensions/superboogav2/config.json`.
|
It expects a JSON file in `extensions/superboogav2/config.json`.
|
||||||
|
|
||||||
Each element in the JSON must have a `default` value which will be used for the current run. Elements can have `categories`.
|
Each element in the JSON must have a `default` value which will be used for the current run. Elements can have `categories`.
|
||||||
These categories define the range in which the optimizer will search. If the element is tagged with `"should_optimize": false`,
|
These categories define the range in which the optimizer will search. If the element is tagged with `"should_optimize": false`,
|
||||||
then the optimizer will only ever use the default value.
|
then the optimizer will only ever use the default value.
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
NUM_TO_WORD_METHOD = 'Number to Word'
|
NUM_TO_WORD_METHOD = 'Number to Word'
|
||||||
NUM_TO_CHAR_METHOD = 'Number to Char'
|
NUM_TO_CHAR_METHOD = 'Number to Char'
|
||||||
NUM_TO_CHAR_LONG_METHOD = 'Number to Multi-Char'
|
NUM_TO_CHAR_LONG_METHOD = 'Number to Multi-Char'
|
||||||
@ -366,4 +364,4 @@ def set_api_port(value: int):
|
|||||||
|
|
||||||
|
|
||||||
def set_api_on(value: bool):
|
def set_api_on(value: bool):
|
||||||
Parameters.getInstance().hyperparameters['api_on']['default'] = value
|
Parameters.getInstance().hyperparameters['api_on']['default'] = value
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
beautifulsoup4==4.12.2
|
beautifulsoup4==4.12.2
|
||||||
chromadb==0.3.18
|
chromadb==0.4.24
|
||||||
lxml
|
lxml
|
||||||
optuna
|
optuna
|
||||||
pandas==2.0.3
|
pandas==2.0.3
|
||||||
@ -7,4 +7,4 @@ posthog==2.4.2
|
|||||||
sentence_transformers==2.2.2
|
sentence_transformers==2.2.2
|
||||||
spacy
|
spacy
|
||||||
pytextrank
|
pytextrank
|
||||||
num2words
|
num2words
|
||||||
|
@ -7,28 +7,29 @@ from pathlib import Path
|
|||||||
# Point to where nltk will find the required data.
|
# Point to where nltk will find the required data.
|
||||||
os.environ['NLTK_DATA'] = str(Path("extensions/superboogav2/nltk_data").resolve())
|
os.environ['NLTK_DATA'] = str(Path("extensions/superboogav2/nltk_data").resolve())
|
||||||
|
|
||||||
import textwrap
|
|
||||||
import codecs
|
import codecs
|
||||||
|
import textwrap
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
|
||||||
from modules.logging_colors import logger
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
from .utils import create_metadata_source
|
|
||||||
from .chromadb import make_collector
|
|
||||||
from .download_urls import feed_url_into_collector
|
|
||||||
from .data_processor import process_and_add_to_collector
|
|
||||||
from .benchmark import benchmark
|
|
||||||
from .optimize import optimize
|
|
||||||
from .notebook_handler import input_modifier_internal
|
|
||||||
from .chat_handler import custom_generate_chat_prompt_internal
|
|
||||||
from .api import APIManager
|
from .api import APIManager
|
||||||
|
from .benchmark import benchmark
|
||||||
|
from .chat_handler import custom_generate_chat_prompt_internal
|
||||||
|
from .chromadb import make_collector
|
||||||
|
from .data_processor import process_and_add_to_collector
|
||||||
|
from .download_urls import feed_url_into_collector
|
||||||
|
from .notebook_handler import input_modifier_internal
|
||||||
|
from .optimize import optimize
|
||||||
|
from .utils import create_metadata_source
|
||||||
|
|
||||||
collector = None
|
collector = None
|
||||||
api_manager = None
|
api_manager = None
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
global collector
|
global collector
|
||||||
global api_manager
|
global api_manager
|
||||||
@ -38,6 +39,7 @@ def setup():
|
|||||||
if parameters.get_api_on():
|
if parameters.get_api_on():
|
||||||
api_manager.start_server(parameters.get_api_port())
|
api_manager.start_server(parameters.get_api_port())
|
||||||
|
|
||||||
|
|
||||||
def _feed_data_into_collector(corpus):
|
def _feed_data_into_collector(corpus):
|
||||||
yield '### Processing data...'
|
yield '### Processing data...'
|
||||||
process_and_add_to_collector(corpus, collector, False, create_metadata_source('direct-text'))
|
process_and_add_to_collector(corpus, collector, False, create_metadata_source('direct-text'))
|
||||||
@ -87,7 +89,7 @@ def _get_optimizable_settings() -> list:
|
|||||||
preprocess_pipeline.append('Merge Spaces')
|
preprocess_pipeline.append('Merge Spaces')
|
||||||
if parameters.should_strip():
|
if parameters.should_strip():
|
||||||
preprocess_pipeline.append('Strip Edges')
|
preprocess_pipeline.append('Strip Edges')
|
||||||
|
|
||||||
return [
|
return [
|
||||||
parameters.get_time_power(),
|
parameters.get_time_power(),
|
||||||
parameters.get_time_steepness(),
|
parameters.get_time_steepness(),
|
||||||
@ -104,8 +106,8 @@ def _get_optimizable_settings() -> list:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _apply_settings(optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
def _apply_settings(optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||||
preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
|
preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
|
||||||
chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup):
|
chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup):
|
||||||
logger.debug('Applying settings.')
|
logger.debug('Applying settings.')
|
||||||
|
|
||||||
@ -240,7 +242,7 @@ def ui():
|
|||||||
with gr.Tab("File input"):
|
with gr.Tab("File input"):
|
||||||
file_input = gr.File(label='Input file', type='binary')
|
file_input = gr.File(label='Input file', type='binary')
|
||||||
update_file = gr.Button('Load data')
|
update_file = gr.Button('Load data')
|
||||||
|
|
||||||
with gr.Tab("Settings"):
|
with gr.Tab("Settings"):
|
||||||
with gr.Accordion("Processing settings", open=True):
|
with gr.Accordion("Processing settings", open=True):
|
||||||
chunk_len = gr.Textbox(value=parameters.get_chunk_len(), label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
|
chunk_len = gr.Textbox(value=parameters.get_chunk_len(), label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
|
||||||
@ -305,19 +307,16 @@ def ui():
|
|||||||
optimize_button = gr.Button('Optimize')
|
optimize_button = gr.Button('Optimize')
|
||||||
optimization_steps = gr.Number(value=parameters.get_optimization_steps(), label='Optimization Steps', info='For how many steps to optimize.', interactive=True)
|
optimization_steps = gr.Number(value=parameters.get_optimization_steps(), label='Optimization Steps', info='For how many steps to optimize.', interactive=True)
|
||||||
|
|
||||||
|
|
||||||
clear_button = gr.Button('❌ Clear Data')
|
clear_button = gr.Button('❌ Clear Data')
|
||||||
|
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
last_updated = gr.Markdown()
|
last_updated = gr.Markdown()
|
||||||
|
|
||||||
all_params = [optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
all_params = [optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||||
preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
|
preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
|
||||||
chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup]
|
chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup]
|
||||||
optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||||
preprocess_pipeline, chunk_count, context_len, chunk_len]
|
preprocess_pipeline, chunk_count, context_len, chunk_len]
|
||||||
|
|
||||||
|
|
||||||
update_data.click(_feed_data_into_collector, [data_input], last_updated, show_progress=False)
|
update_data.click(_feed_data_into_collector, [data_input], last_updated, show_progress=False)
|
||||||
update_url.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)
|
update_url.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)
|
||||||
@ -326,7 +325,6 @@ def ui():
|
|||||||
optimize_button.click(_begin_optimization, [], [last_updated] + optimizable_params, show_progress=True)
|
optimize_button.click(_begin_optimization, [], [last_updated] + optimizable_params, show_progress=True)
|
||||||
clear_button.click(_clear_data, [], last_updated, show_progress=False)
|
clear_button.click(_clear_data, [], last_updated, show_progress=False)
|
||||||
|
|
||||||
|
|
||||||
optimization_steps.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
optimization_steps.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||||
time_power.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
time_power.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||||
time_steepness.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
time_steepness.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||||
@ -352,4 +350,4 @@ def ui():
|
|||||||
chunk_regex.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
chunk_regex.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||||
chunk_len.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
chunk_len.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||||
threads.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
threads.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||||
strong_cleanup.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
strong_cleanup.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||||
|
@ -4,6 +4,7 @@ This module contains common functions across multiple other modules.
|
|||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
|
||||||
|
|
||||||
# Create the context using the prefix + data_separator + postfix from parameters.
|
# Create the context using the prefix + data_separator + postfix from parameters.
|
||||||
def create_context_text(results):
|
def create_context_text(results):
|
||||||
context = parameters.get_prefix() + parameters.get_data_separator().join(results) + parameters.get_postfix()
|
context = parameters.get_prefix() + parameters.get_data_separator().join(results) + parameters.get_postfix()
|
||||||
@ -13,4 +14,4 @@ def create_context_text(results):
|
|||||||
|
|
||||||
# Create metadata with the specified source
|
# Create metadata with the specified source
|
||||||
def create_metadata_source(source: str):
|
def create_metadata_source(source: str):
|
||||||
return {'source': source}
|
return {'source': source}
|
||||||
|
Loading…
Reference in New Issue
Block a user