mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-26 12:22:08 +01:00
Make superbooga & superboogav2 functional again (#5656)
This commit is contained in:
parent
bae14c8f13
commit
2681f6f640
@ -1,43 +1,24 @@
|
||||
import random
|
||||
|
||||
import chromadb
|
||||
import posthog
|
||||
import torch
|
||||
from chromadb.config import Settings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from modules.logging_colors import logger
|
||||
|
||||
logger.info('Intercepting all calls to posthog :)')
|
||||
# Intercept calls to posthog
|
||||
posthog.capture = lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class Collecter():
|
||||
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
|
||||
|
||||
|
||||
class ChromaCollector():
|
||||
def __init__(self):
|
||||
pass
|
||||
name = ''.join(random.choice('ab') for _ in range(10))
|
||||
|
||||
def add(self, texts: list[str]):
|
||||
pass
|
||||
|
||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
pass
|
||||
|
||||
def clear(self):
|
||||
pass
|
||||
|
||||
|
||||
class Embedder():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def embed(self, text: str) -> list[torch.Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
class ChromaCollector(Collecter):
|
||||
def __init__(self, embedder: Embedder):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
||||
self.embedder = embedder
|
||||
self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
|
||||
self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
|
||||
self.ids = []
|
||||
|
||||
def add(self, texts: list[str]):
|
||||
@ -102,24 +83,15 @@ class ChromaCollector(Collecter):
|
||||
return sorted(ids)
|
||||
|
||||
def clear(self):
|
||||
self.collection.delete(ids=self.ids)
|
||||
self.ids = []
|
||||
|
||||
|
||||
class SentenceTransformerEmbedder(Embedder):
|
||||
def __init__(self) -> None:
|
||||
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
|
||||
self.embed = self.model.encode
|
||||
self.chroma_client.delete_collection(name=self.name)
|
||||
self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
|
||||
|
||||
|
||||
def make_collector():
|
||||
global embedder
|
||||
return ChromaCollector(embedder)
|
||||
return ChromaCollector()
|
||||
|
||||
|
||||
def add_chunks_to_collector(chunks, collector):
|
||||
collector.clear()
|
||||
collector.add(chunks)
|
||||
|
||||
|
||||
embedder = SentenceTransformerEmbedder()
|
||||
|
@ -1,5 +1,5 @@
|
||||
beautifulsoup4==4.12.2
|
||||
chromadb==0.3.18
|
||||
chromadb==0.4.24
|
||||
pandas==2.0.3
|
||||
posthog==2.4.2
|
||||
sentence_transformers==2.2.2
|
||||
|
@ -12,17 +12,16 @@ This module is responsible for the VectorDB API. It currently supports:
|
||||
|
||||
import json
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
from threading import Thread
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
from .chromadb import ChromaCollector
|
||||
from .data_processor import process_and_add_to_collector
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
|
||||
class CustomThreadingHTTPServer(ThreadingHTTPServer):
|
||||
def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
|
||||
@ -38,7 +37,6 @@ class Handler(BaseHTTPRequestHandler):
|
||||
self.collector = collector
|
||||
super().__init__(request, client_address, server)
|
||||
|
||||
|
||||
def _send_412_error(self, message):
|
||||
self.send_response(412)
|
||||
self.send_header("Content-type", "application/json")
|
||||
@ -46,7 +44,6 @@ class Handler(BaseHTTPRequestHandler):
|
||||
response = json.dumps({"error": message})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
|
||||
def _send_404_error(self):
|
||||
self.send_response(404)
|
||||
self.send_header("Content-type", "application/json")
|
||||
@ -54,14 +51,12 @@ class Handler(BaseHTTPRequestHandler):
|
||||
response = json.dumps({"error": "Resource not found"})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
|
||||
def _send_400_error(self, error_message: str):
|
||||
self.send_response(400)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
response = json.dumps({"error": error_message})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
|
||||
def _send_200_response(self, message: str):
|
||||
self.send_response(200)
|
||||
@ -75,24 +70,21 @@ class Handler(BaseHTTPRequestHandler):
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
|
||||
def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
|
||||
if sort_param == parameters.SORT_DISTANCE:
|
||||
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
||||
elif sort_param == parameters.SORT_ID:
|
||||
results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count)
|
||||
else: # Default is dist
|
||||
else: # Default is dist
|
||||
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
||||
|
||||
|
||||
return {
|
||||
"results": results
|
||||
}
|
||||
|
||||
|
||||
def do_GET(self):
|
||||
self._send_404_error()
|
||||
|
||||
|
||||
def do_POST(self):
|
||||
try:
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
@ -107,7 +99,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
if corpus is None:
|
||||
self._send_412_error("Missing parameter 'corpus'")
|
||||
return
|
||||
|
||||
|
||||
clear_before_adding = body.get('clear_before_adding', False)
|
||||
metadata = body.get('metadata')
|
||||
process_and_add_to_collector(corpus, self.collector, clear_before_adding, metadata)
|
||||
@ -118,7 +110,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
if corpus is None:
|
||||
self._send_412_error("Missing parameter 'metadata'")
|
||||
return
|
||||
|
||||
|
||||
self.collector.delete(ids_to_delete=None, where=metadata)
|
||||
self._send_200_response("Data successfully deleted")
|
||||
|
||||
@ -127,15 +119,15 @@ class Handler(BaseHTTPRequestHandler):
|
||||
if search_strings is None:
|
||||
self._send_412_error("Missing parameter 'search_strings'")
|
||||
return
|
||||
|
||||
|
||||
n_results = body.get('n_results')
|
||||
if n_results is None:
|
||||
n_results = parameters.get_chunk_count()
|
||||
|
||||
|
||||
max_token_count = body.get('max_token_count')
|
||||
if max_token_count is None:
|
||||
max_token_count = parameters.get_max_token_count()
|
||||
|
||||
|
||||
sort_param = query_params.get('sort', ['distance'])[0]
|
||||
|
||||
results = self._handle_get(search_strings, n_results, max_token_count, sort_param)
|
||||
@ -146,7 +138,6 @@ class Handler(BaseHTTPRequestHandler):
|
||||
except Exception as e:
|
||||
self._send_400_error(str(e))
|
||||
|
||||
|
||||
def do_DELETE(self):
|
||||
try:
|
||||
parsed_path = urlparse(self.path)
|
||||
@ -161,12 +152,10 @@ class Handler(BaseHTTPRequestHandler):
|
||||
except Exception as e:
|
||||
self._send_400_error(str(e))
|
||||
|
||||
|
||||
def do_OPTIONS(self):
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
|
||||
|
||||
def end_headers(self):
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Methods', '*')
|
||||
@ -197,11 +186,11 @@ class APIManager:
|
||||
|
||||
def stop_server(self):
|
||||
if self.server is not None:
|
||||
logger.info(f'Stopping chromaDB API.')
|
||||
logger.info('Stopping chromaDB API.')
|
||||
self.server.shutdown()
|
||||
self.server.server_close()
|
||||
self.server = None
|
||||
self.is_running = False
|
||||
|
||||
def is_server_running(self):
|
||||
return self.is_running
|
||||
return self.is_running
|
||||
|
@ -9,23 +9,23 @@ The benchmark function will return the score as an integer.
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .data_processor import process_and_add_to_collector, preprocess_text
|
||||
from .data_processor import preprocess_text, process_and_add_to_collector
|
||||
from .parameters import get_chunk_count, get_max_token_count
|
||||
from .utils import create_metadata_source
|
||||
|
||||
|
||||
def benchmark(config_path, collector):
|
||||
# Get the current system date
|
||||
sysdate = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"benchmark_{sysdate}.txt"
|
||||
|
||||
|
||||
# Open the log file in append mode
|
||||
with open(filename, 'a') as log:
|
||||
with open(config_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
total_points = 0
|
||||
max_points = 0
|
||||
|
||||
@ -45,7 +45,7 @@ def benchmark(config_path, collector):
|
||||
for question_group in item["questions"]:
|
||||
question_variants = question_group["question_variants"]
|
||||
criteria = question_group["criteria"]
|
||||
|
||||
|
||||
for q in question_variants:
|
||||
max_points += len(criteria)
|
||||
processed_text = preprocess_text(q)
|
||||
@ -54,7 +54,7 @@ def benchmark(config_path, collector):
|
||||
results = collector.get_sorted_by_dist(processed_text, n_results=get_chunk_count(), max_token_count=get_max_token_count())
|
||||
|
||||
points = 0
|
||||
|
||||
|
||||
for c in criteria:
|
||||
for p in results:
|
||||
if c in p:
|
||||
@ -69,4 +69,4 @@ def benchmark(config_path, collector):
|
||||
|
||||
print(f'##Total points:\n\n{total_points}/{max_points}', file=log)
|
||||
|
||||
return total_points, max_points
|
||||
return total_points, max_points
|
||||
|
@ -4,16 +4,17 @@ This module is responsible for modifying the chat prompt and history.
|
||||
import re
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
from extensions.superboogav2.utils import (
|
||||
create_context_text,
|
||||
create_metadata_source
|
||||
)
|
||||
from modules import chat, shared
|
||||
from modules.text_generation import get_encoded_length
|
||||
from modules.logging_colors import logger
|
||||
from modules.chat import load_character_memoized
|
||||
from extensions.superboogav2.utils import create_context_text, create_metadata_source
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import get_encoded_length
|
||||
|
||||
from .data_processor import process_and_add_to_collector
|
||||
from .chromadb import ChromaCollector
|
||||
|
||||
from .data_processor import process_and_add_to_collector
|
||||
|
||||
CHAT_METADATA = create_metadata_source('automatic-chat-insert')
|
||||
|
||||
@ -21,17 +22,17 @@ CHAT_METADATA = create_metadata_source('automatic-chat-insert')
|
||||
def _remove_tag_if_necessary(user_input: str):
|
||||
if not parameters.get_is_manual():
|
||||
return user_input
|
||||
|
||||
|
||||
return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input)
|
||||
|
||||
|
||||
def _should_query(input: str):
|
||||
if not parameters.get_is_manual():
|
||||
return True
|
||||
|
||||
|
||||
if re.search(r'^\s*!c|!c\s*$', input, re.MULTILINE):
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@ -69,7 +70,7 @@ def _concatinate_history(history: dict, state: dict):
|
||||
if len(exchange) >= 2:
|
||||
full_history_text += _format_single_exchange(bot_name, exchange[1])
|
||||
|
||||
return full_history_text[:-1] # Remove the last new line.
|
||||
return full_history_text[:-1] # Remove the last new line.
|
||||
|
||||
|
||||
def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
|
||||
@ -82,20 +83,20 @@ def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
|
||||
for i, messages in enumerate(reversed(history['internal'])):
|
||||
for j, message in enumerate(reversed(messages)):
|
||||
num_message_tokens = get_encoded_length(_format_single_exchange(names[j], message))
|
||||
|
||||
|
||||
# TODO: This is an extremely naive solution. A more robust implementation must be made.
|
||||
if history_tokens + num_context_tokens <= max_len:
|
||||
# This message can be replaced
|
||||
replace_position = (i, j)
|
||||
|
||||
|
||||
history_tokens += num_message_tokens
|
||||
|
||||
|
||||
if replace_position is None:
|
||||
logger.warn("The provided context_text is too long to replace any message in the history.")
|
||||
else:
|
||||
# replace the message at replace_position with context_text
|
||||
i, j = replace_position
|
||||
history['internal'][-i-1][-j-1] = context_text
|
||||
history['internal'][-i - 1][-j - 1] = context_text
|
||||
|
||||
|
||||
def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):
|
||||
@ -120,5 +121,5 @@ def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector
|
||||
user_input = create_context_text(results) + user_input
|
||||
elif parameters.get_injection_strategy() == parameters.HIJACK_LAST_IN_CONTEXT:
|
||||
_hijack_last(create_context_text(results), kwargs['history'], state['truncation_length'], state)
|
||||
|
||||
|
||||
return chat.generate_chat_prompt(user_input, state, **kwargs)
|
||||
|
@ -1,42 +1,23 @@
|
||||
import threading
|
||||
import chromadb
|
||||
import posthog
|
||||
import torch
|
||||
import math
|
||||
import random
|
||||
import threading
|
||||
|
||||
import chromadb
|
||||
import numpy as np
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
import posthog
|
||||
from chromadb.config import Settings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import encode, decode
|
||||
from modules.text_generation import decode, encode
|
||||
|
||||
logger.debug('Intercepting all calls to posthog.')
|
||||
# Intercept calls to posthog
|
||||
posthog.capture = lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class Collecter():
|
||||
def __init__(self):
|
||||
pass
|
||||
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
|
||||
|
||||
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int]):
|
||||
pass
|
||||
|
||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
pass
|
||||
|
||||
def clear(self):
|
||||
pass
|
||||
|
||||
|
||||
class Embedder():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def embed(self, text: str) -> list[torch.Tensor]:
|
||||
pass
|
||||
|
||||
class Info:
|
||||
def __init__(self, start_index, text_with_context, distance, id):
|
||||
@ -58,7 +39,7 @@ class Info:
|
||||
elif parameters.get_new_dist_strategy() == parameters.DIST_ARITHMETIC_STRATEGY:
|
||||
# Arithmetic mean
|
||||
return (self.distance + other_info.distance) / 2
|
||||
else: # Min is default
|
||||
else: # Min is default
|
||||
return min(self.distance, other_info.distance)
|
||||
|
||||
def merge_with(self, other_info):
|
||||
@ -66,7 +47,7 @@ class Info:
|
||||
s2 = other_info.text_with_context
|
||||
s1_start = self.start_index
|
||||
s2_start = other_info.start_index
|
||||
|
||||
|
||||
new_dist = self.calculate_distance(other_info)
|
||||
|
||||
if self.should_merge(s1, s2, s1_start, s2_start):
|
||||
@ -84,55 +65,58 @@ class Info:
|
||||
return Info(s2_start, s2 + s1[overlap:], new_dist, other_info.id)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def should_merge(s1, s2, s1_start, s2_start):
|
||||
# Check if s1 and s2 are adjacent or overlapping
|
||||
s1_end = s1_start + len(s1)
|
||||
s2_end = s2_start + len(s2)
|
||||
|
||||
|
||||
return not (s1_end < s2_start or s2_end < s1_start)
|
||||
|
||||
class ChromaCollector(Collecter):
|
||||
def __init__(self, embedder: Embedder):
|
||||
super().__init__()
|
||||
|
||||
class ChromaCollector():
|
||||
def __init__(self):
|
||||
name = ''.join(random.choice('ab') for _ in range(10))
|
||||
|
||||
self.name = name
|
||||
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
||||
self.embedder = embedder
|
||||
self.collection = self.chroma_client.create_collection(name="context", embedding_function=self.embedder.embed)
|
||||
self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
|
||||
|
||||
self.ids = []
|
||||
self.id_to_info = {}
|
||||
self.embeddings_cache = {}
|
||||
self.lock = threading.Lock() # Locking so the server doesn't break.
|
||||
self.lock = threading.Lock() # Locking so the server doesn't break.
|
||||
|
||||
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None):
|
||||
with self.lock:
|
||||
assert metadatas is None or len(metadatas) == len(texts), "metadatas must be None or have the same length as texts"
|
||||
|
||||
if len(texts) == 0:
|
||||
|
||||
if len(texts) == 0:
|
||||
return
|
||||
|
||||
new_ids = self._get_new_ids(len(texts))
|
||||
|
||||
(existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
||||
(non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas)
|
||||
(non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas)
|
||||
|
||||
# If there are any already existing texts, add them all at once.
|
||||
if existing_texts:
|
||||
logger.info(f'Adding {len(existing_embeddings)} cached embeddings.')
|
||||
args = {'embeddings': existing_embeddings, 'documents': existing_texts, 'ids': existing_ids}
|
||||
if metadatas is not None:
|
||||
if metadatas is not None:
|
||||
args['metadatas'] = existing_metas
|
||||
self.collection.add(**args)
|
||||
|
||||
# If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
|
||||
if non_existing_texts:
|
||||
non_existing_embeddings = self.embedder.embed(non_existing_texts).tolist()
|
||||
non_existing_embeddings = embedder(non_existing_texts)
|
||||
for text, embedding in zip(non_existing_texts, non_existing_embeddings):
|
||||
self.embeddings_cache[text] = embedding
|
||||
|
||||
logger.info(f'Adding {len(non_existing_embeddings)} new embeddings.')
|
||||
args = {'embeddings': non_existing_embeddings, 'documents': non_existing_texts, 'ids': non_existing_ids}
|
||||
if metadatas is not None:
|
||||
if metadatas is not None:
|
||||
args['metadatas'] = non_existing_metas
|
||||
self.collection.add(**args)
|
||||
|
||||
@ -145,7 +129,6 @@ class ChromaCollector(Collecter):
|
||||
self.id_to_info.update(new_info)
|
||||
self.ids.extend(new_ids)
|
||||
|
||||
|
||||
def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]):
|
||||
existing_texts, non_existing_texts = [], []
|
||||
existing_embeddings = []
|
||||
@ -169,7 +152,6 @@ class ChromaCollector(Collecter):
|
||||
return (existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
||||
(non_existing_texts, non_existing_ids, non_existing_metas)
|
||||
|
||||
|
||||
def _get_new_ids(self, num_new_ids: int):
|
||||
if self.ids:
|
||||
max_existing_id = max(int(id_) for id_ in self.ids)
|
||||
@ -178,7 +160,6 @@ class ChromaCollector(Collecter):
|
||||
|
||||
return [str(i + max_existing_id + 1) for i in range(num_new_ids)]
|
||||
|
||||
|
||||
def _find_min_max_start_index(self):
|
||||
max_index, min_index = 0, float('inf')
|
||||
for _, val in self.id_to_info.items():
|
||||
@ -188,34 +169,34 @@ class ChromaCollector(Collecter):
|
||||
min_index = val['start_index']
|
||||
return min_index, max_index
|
||||
|
||||
|
||||
# NB: Does not make sense to weigh excerpts from different documents.
|
||||
# NB: Does not make sense to weigh excerpts from different documents.
|
||||
# But let's say that's the user's problem. Perfect world scenario:
|
||||
# Apply time weighing to different documents. For each document, then, add
|
||||
# separate time weighing.
|
||||
|
||||
def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float):
|
||||
sigmoid = lambda x: 1 / (1 + np.exp(-x))
|
||||
|
||||
def sigmoid(x):
|
||||
return 1 / (1 + np.exp(-x))
|
||||
|
||||
weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))
|
||||
|
||||
# Scale to [0,time_power] and shift it up to [1-time_power, 1]
|
||||
weights = weights - min(weights)
|
||||
weights = weights - min(weights)
|
||||
weights = weights * (time_power / max(weights))
|
||||
weights = weights + (1 - time_power)
|
||||
weights = weights + (1 - time_power)
|
||||
|
||||
# Reverse the weights
|
||||
weights = weights[::-1]
|
||||
weights = weights[::-1]
|
||||
|
||||
for info in infos:
|
||||
index = info.start_index
|
||||
info.distance *= weights[index]
|
||||
|
||||
|
||||
def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float):
|
||||
# Ensure there are infos to filter
|
||||
if not infos:
|
||||
return []
|
||||
|
||||
|
||||
# Find info with minimum distance
|
||||
min_info = min(infos, key=lambda x: x.distance)
|
||||
|
||||
@ -231,7 +212,6 @@ class ChromaCollector(Collecter):
|
||||
|
||||
return filtered_infos
|
||||
|
||||
|
||||
def _merge_infos(self, infos: list[Info]):
|
||||
merged_infos = []
|
||||
current_info = infos[0]
|
||||
@ -247,8 +227,8 @@ class ChromaCollector(Collecter):
|
||||
merged_infos.append(current_info)
|
||||
return merged_infos
|
||||
|
||||
|
||||
# Main function for retrieving chunks by distance. It performs merging, time weighing, and mean filtering.
|
||||
|
||||
def _get_documents_ids_distances(self, search_strings: list[str], n_results: int):
|
||||
n_results = min(len(self.ids), n_results)
|
||||
if n_results == 0:
|
||||
@ -262,11 +242,11 @@ class ChromaCollector(Collecter):
|
||||
|
||||
for search_string in search_strings:
|
||||
result = self.collection.query(query_texts=search_string, n_results=math.ceil(n_results / len(search_strings)), include=['distances'])
|
||||
curr_infos = [Info(start_index=self.id_to_info[id]['start_index'],
|
||||
text_with_context=self.id_to_info[id]['text_with_context'],
|
||||
distance=distance, id=id)
|
||||
curr_infos = [Info(start_index=self.id_to_info[id]['start_index'],
|
||||
text_with_context=self.id_to_info[id]['text_with_context'],
|
||||
distance=distance, id=id)
|
||||
for id, distance in zip(result['ids'][0], result['distances'][0])]
|
||||
|
||||
|
||||
self._apply_sigmoid_time_weighing(infos=curr_infos, document_len=max_start_index - min_start_index + 1, time_steepness=parameters.get_time_steepness(), time_power=parameters.get_time_power())
|
||||
curr_infos = self._filter_outliers_by_median_distance(curr_infos, parameters.get_significant_level())
|
||||
infos.extend(curr_infos)
|
||||
@ -279,23 +259,23 @@ class ChromaCollector(Collecter):
|
||||
distances = [inf.distance for inf in infos]
|
||||
|
||||
return texts_with_context, ids, distances
|
||||
|
||||
|
||||
# Get chunks by similarity
|
||||
|
||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
with self.lock:
|
||||
documents, _, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||
return documents
|
||||
|
||||
|
||||
# Get ids by similarity
|
||||
|
||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
with self.lock:
|
||||
_, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||
return ids
|
||||
|
||||
|
||||
|
||||
# Cutoff token count
|
||||
|
||||
def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int):
|
||||
# TODO: Move to caller; We add delimiters there which might go over the limit.
|
||||
current_token_count = 0
|
||||
@ -308,7 +288,7 @@ class ChromaCollector(Collecter):
|
||||
# If adding this document would exceed the max token count,
|
||||
# truncate the document to fit within the limit.
|
||||
remaining_tokens = max_token_count - current_token_count
|
||||
|
||||
|
||||
truncated_doc = decode(doc_tokens[:remaining_tokens], skip_special_tokens=True)
|
||||
return_documents.append(truncated_doc)
|
||||
break
|
||||
@ -317,29 +297,28 @@ class ChromaCollector(Collecter):
|
||||
current_token_count += doc_token_count
|
||||
|
||||
return return_documents
|
||||
|
||||
|
||||
# Get chunks by similarity and then sort by ids
|
||||
|
||||
def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
||||
with self.lock:
|
||||
documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||
sorted_docs = [x for _, x in sorted(zip(ids, documents))]
|
||||
|
||||
return self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
||||
|
||||
|
||||
|
||||
# Get chunks by similarity and then sort by distance (lowest distance is last).
|
||||
|
||||
def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
||||
with self.lock:
|
||||
documents, _, distances = self._get_documents_ids_distances(search_strings, n_results)
|
||||
sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])] # sorted lowest -> highest
|
||||
|
||||
sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])] # sorted lowest -> highest
|
||||
|
||||
# If a document is truncated or competely skipped, it would be with high distance.
|
||||
return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
||||
return_documents.reverse() # highest -> lowest
|
||||
return_documents.reverse() # highest -> lowest
|
||||
|
||||
return return_documents
|
||||
|
||||
|
||||
def delete(self, ids_to_delete: list[str], where: dict):
|
||||
with self.lock:
|
||||
@ -354,23 +333,16 @@ class ChromaCollector(Collecter):
|
||||
|
||||
logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.')
|
||||
|
||||
|
||||
def clear(self):
|
||||
with self.lock:
|
||||
self.chroma_client.reset()
|
||||
self.collection = self.chroma_client.create_collection("context", embedding_function=self.embedder.embed)
|
||||
|
||||
self.ids = []
|
||||
self.id_to_info = {}
|
||||
self.chroma_client.delete_collection(name=self.name)
|
||||
self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
|
||||
|
||||
logger.info('Successfully cleared all records and reset chromaDB.')
|
||||
|
||||
|
||||
class SentenceTransformerEmbedder(Embedder):
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Creating Sentence Embedder...')
|
||||
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
|
||||
self.embed = self.model.encode
|
||||
|
||||
|
||||
def make_collector():
|
||||
return ChromaCollector(SentenceTransformerEmbedder())
|
||||
return ChromaCollector()
|
||||
|
@ -11,32 +11,29 @@ This module contains utils for preprocessing the text before converting it to em
|
||||
* removing specific parts of speech (adverbs and interjections)
|
||||
- TextSummarizer extracts the most important sentences from a long string using text-ranking.
|
||||
"""
|
||||
import pytextrank
|
||||
import string
|
||||
import spacy
|
||||
import math
|
||||
import nltk
|
||||
import re
|
||||
import string
|
||||
|
||||
import nltk
|
||||
import spacy
|
||||
from nltk.corpus import stopwords
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
from num2words import num2words
|
||||
|
||||
|
||||
class TextPreprocessorBuilder:
|
||||
# Define class variables as None initially
|
||||
# Define class variables as None initially
|
||||
_stop_words = set(stopwords.words('english'))
|
||||
_lemmatizer = WordNetLemmatizer()
|
||||
|
||||
|
||||
# Some of the functions are expensive. We cache the results.
|
||||
_lemmatizer_cache = {}
|
||||
_pos_remove_cache = {}
|
||||
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
|
||||
def to_lower(self):
|
||||
# Match both words and non-word characters
|
||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||
@ -49,7 +46,6 @@ class TextPreprocessorBuilder:
|
||||
self.text = "".join(tokens)
|
||||
return self
|
||||
|
||||
|
||||
def num_to_word(self, min_len: int = 1):
|
||||
# Match both words and non-word characters
|
||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||
@ -58,11 +54,10 @@ class TextPreprocessorBuilder:
|
||||
if token.isdigit() and len(token) >= min_len:
|
||||
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
||||
# 740700 will become "seven hundred and forty thousand seven hundred".
|
||||
tokens[i] = num2words(int(token)).replace(",","") # Remove commas from num2words.
|
||||
tokens[i] = num2words(int(token)).replace(",", "") # Remove commas from num2words.
|
||||
self.text = "".join(tokens)
|
||||
return self
|
||||
|
||||
|
||||
def num_to_char_long(self, min_len: int = 1):
|
||||
# Match both words and non-word characters
|
||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||
@ -71,11 +66,13 @@ class TextPreprocessorBuilder:
|
||||
if token.isdigit() and len(token) >= min_len:
|
||||
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
||||
# 740700 will become HHHHHHEEEEEAAAAHHHAAA
|
||||
convert_token = lambda token: ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]
|
||||
def convert_token(token):
|
||||
return ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]
|
||||
|
||||
tokens[i] = convert_token(tokens[i])
|
||||
self.text = "".join(tokens)
|
||||
return self
|
||||
|
||||
|
||||
def num_to_char(self, min_len: int = 1):
|
||||
# Match both words and non-word characters
|
||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||
@ -87,15 +84,15 @@ class TextPreprocessorBuilder:
|
||||
tokens[i] = ''.join(chr(int(digit) + 65) for digit in token)
|
||||
self.text = "".join(tokens)
|
||||
return self
|
||||
|
||||
|
||||
def merge_spaces(self):
|
||||
self.text = re.sub(' +', ' ', self.text)
|
||||
return self
|
||||
|
||||
|
||||
def strip(self):
|
||||
self.text = self.text.strip()
|
||||
return self
|
||||
|
||||
|
||||
def remove_punctuation(self):
|
||||
self.text = self.text.translate(str.maketrans('', '', string.punctuation))
|
||||
return self
|
||||
@ -103,7 +100,7 @@ class TextPreprocessorBuilder:
|
||||
def remove_stopwords(self):
|
||||
self.text = "".join([word for word in re.findall(r'\b\w+\b|\W+', self.text) if word not in TextPreprocessorBuilder._stop_words])
|
||||
return self
|
||||
|
||||
|
||||
def remove_specific_pos(self):
|
||||
"""
|
||||
In the English language, adverbs and interjections rarely provide meaningul information.
|
||||
@ -140,7 +137,7 @@ class TextPreprocessorBuilder:
|
||||
if processed_text:
|
||||
self.text = processed_text
|
||||
return self
|
||||
|
||||
|
||||
new_text = "".join([TextPreprocessorBuilder._lemmatizer.lemmatize(word) for word in re.findall(r'\b\w+\b|\W+', self.text)])
|
||||
TextPreprocessorBuilder._lemmatizer_cache[self.text] = new_text
|
||||
self.text = new_text
|
||||
@ -150,6 +147,7 @@ class TextPreprocessorBuilder:
|
||||
def build(self):
|
||||
return self.text
|
||||
|
||||
|
||||
class TextSummarizer:
|
||||
_nlp_pipeline = None
|
||||
_cache = {}
|
||||
@ -165,7 +163,7 @@ class TextSummarizer:
|
||||
@staticmethod
|
||||
def process_long_text(text: str, min_num_sent: int) -> list[str]:
|
||||
"""
|
||||
This function applies a text summarization process on a given text string, extracting
|
||||
This function applies a text summarization process on a given text string, extracting
|
||||
the most important sentences based on the principle that 20% of the content is responsible
|
||||
for 80% of the meaning (the Pareto Principle).
|
||||
|
||||
@ -193,7 +191,7 @@ class TextSummarizer:
|
||||
|
||||
else:
|
||||
result = [text]
|
||||
|
||||
|
||||
# Store the result in cache before returning it
|
||||
TextSummarizer._cache[cache_key] = result
|
||||
return result
|
||||
return result
|
||||
|
@ -1,16 +1,17 @@
|
||||
"""
|
||||
This module is responsible for processing the corpus and feeding it into chromaDB. It will receive a corpus of text.
|
||||
This module is responsible for processing the corpus and feeding it into chromaDB. It will receive a corpus of text.
|
||||
It will then split it into chunks of specified length. For each of those chunks, it will append surrounding context.
|
||||
It will only include full words.
|
||||
"""
|
||||
|
||||
import re
|
||||
import bisect
|
||||
import re
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer
|
||||
from .chromadb import ChromaCollector
|
||||
from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer
|
||||
|
||||
|
||||
def preprocess_text_no_summary(text) -> str:
|
||||
builder = TextPreprocessorBuilder(text)
|
||||
@ -42,7 +43,7 @@ def preprocess_text_no_summary(text) -> str:
|
||||
builder.num_to_char(parameters.get_min_num_length())
|
||||
elif parameters.get_num_conversion_strategy() == parameters.NUM_TO_CHAR_LONG_METHOD:
|
||||
builder.num_to_char_long(parameters.get_min_num_length())
|
||||
|
||||
|
||||
return builder.build()
|
||||
|
||||
|
||||
@ -53,10 +54,10 @@ def preprocess_text(text) -> list[str]:
|
||||
|
||||
def _create_chunks_with_context(corpus, chunk_len, context_left, context_right):
|
||||
"""
|
||||
This function takes a corpus of text and splits it into chunks of a specified length,
|
||||
then adds a specified amount of context to each chunk. The context is added by first
|
||||
going backwards from the start of the chunk and then going forwards from the end of the
|
||||
chunk, ensuring that the context includes only whole words and that the total context length
|
||||
This function takes a corpus of text and splits it into chunks of a specified length,
|
||||
then adds a specified amount of context to each chunk. The context is added by first
|
||||
going backwards from the start of the chunk and then going forwards from the end of the
|
||||
chunk, ensuring that the context includes only whole words and that the total context length
|
||||
does not exceed the specified limit. This function uses binary search for efficiency.
|
||||
|
||||
Returns:
|
||||
@ -102,7 +103,7 @@ def _create_chunks_with_context(corpus, chunk_len, context_left, context_right):
|
||||
# Combine all the words in the context range (before, chunk, and after)
|
||||
chunk_with_context = ''.join(words[context_start_index:context_end_index])
|
||||
chunks_with_context.append(chunk_with_context)
|
||||
|
||||
|
||||
# Determine the start index of the chunk with context
|
||||
chunk_with_context_start_index = word_start_indices[context_start_index]
|
||||
chunk_with_context_start_indices.append(chunk_with_context_start_index)
|
||||
@ -125,9 +126,9 @@ def _clear_chunks(data_chunks, data_chunks_with_context, data_chunk_starting_ind
|
||||
seen_chunk_start = seen_chunks.get(chunk)
|
||||
if seen_chunk_start:
|
||||
# If we've already seen this exact chunk, and the context around it it very close to the seen chunk, then skip it.
|
||||
if abs(seen_chunk_start-index) < parameters.get_delta_start():
|
||||
if abs(seen_chunk_start - index) < parameters.get_delta_start():
|
||||
continue
|
||||
|
||||
|
||||
distinct_data_chunks.append(chunk)
|
||||
distinct_data_chunks_with_context.append(context)
|
||||
distinct_data_chunk_starting_indices.append(index)
|
||||
@ -206,4 +207,4 @@ def process_and_add_to_collector(corpus: str, collector: ChromaCollector, clear_
|
||||
|
||||
if clear_collector_before_adding:
|
||||
collector.clear()
|
||||
collector.add(data_chunks, data_chunks_with_context, data_chunk_starting_indices, [metadata]*len(data_chunks) if metadata is not None else None)
|
||||
collector.add(data_chunks, data_chunks_with_context, data_chunk_starting_indices, [metadata] * len(data_chunks) if metadata is not None else None)
|
||||
|
@ -1,7 +1,7 @@
|
||||
import concurrent.futures
|
||||
import requests
|
||||
import re
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
@ -9,6 +9,7 @@ import extensions.superboogav2.parameters as parameters
|
||||
from .data_processor import process_and_add_to_collector
|
||||
from .utils import create_metadata_source
|
||||
|
||||
|
||||
def _download_single(url):
|
||||
response = requests.get(url, timeout=5)
|
||||
if response.status_code == 200:
|
||||
@ -62,4 +63,4 @@ def feed_url_into_collector(urls, collector):
|
||||
text = '\n'.join([s.strip() for s in strings])
|
||||
all_text += text
|
||||
|
||||
process_and_add_to_collector(all_text, collector, False, create_metadata_source('url-download'))
|
||||
process_and_add_to_collector(all_text, collector, False, create_metadata_source('url-download'))
|
||||
|
@ -4,13 +4,12 @@ This module is responsible for handling and modifying the notebook text.
|
||||
import re
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from extensions.superboogav2.utils import create_context_text
|
||||
from modules.logging_colors import logger
|
||||
|
||||
from .data_processor import preprocess_text
|
||||
|
||||
|
||||
def _remove_special_tokens(string):
|
||||
pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
|
||||
return re.sub(pattern, '', string)
|
||||
@ -37,4 +36,4 @@ def input_modifier_internal(string, collector, is_chat):
|
||||
# Make the injection
|
||||
string = string.replace('<|injection-point|>', create_context_text(results))
|
||||
|
||||
return _remove_special_tokens(string)
|
||||
return _remove_special_tokens(string)
|
||||
|
@ -3,22 +3,24 @@ This module implements a hyperparameter optimization routine for the embedding a
|
||||
|
||||
Each run, the optimizer will set the default values inside the hyperparameters. At the end, it will output the best ones it has found.
|
||||
"""
|
||||
import re
|
||||
import hashlib
|
||||
import json
|
||||
import optuna
|
||||
import logging
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import logging
|
||||
import hashlib
|
||||
logging.getLogger('optuna').setLevel(logging.WARNING)
|
||||
import optuna
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
logging.getLogger('optuna').setLevel(logging.WARNING)
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
from modules.logging_colors import logger
|
||||
|
||||
from .benchmark import benchmark
|
||||
from .parameters import Parameters
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
# Format the parameters into markdown format.
|
||||
@ -28,7 +30,7 @@ def _markdown_hyperparams():
|
||||
# Escape any markdown syntax
|
||||
param_name = re.sub(r"([_*\[\]()~`>#+-.!])", r"\\\1", param_name)
|
||||
param_value_default = re.sub(r"([_*\[\]()~`>#+-.!])", r"\\\1", str(param_value['default'])) if param_value['default'] else ' '
|
||||
|
||||
|
||||
res.append('* {}: **{}**'.format(param_name, param_value_default))
|
||||
|
||||
return '\n'.join(res)
|
||||
@ -49,13 +51,13 @@ def _convert_np_types(params):
|
||||
# Set the default values for the hyperparameters.
|
||||
def _set_hyperparameters(params):
|
||||
for param_name, param_value in params.items():
|
||||
if param_name in Parameters.getInstance().hyperparameters:
|
||||
if param_name in Parameters.getInstance().hyperparameters:
|
||||
Parameters.getInstance().hyperparameters[param_name]['default'] = param_value
|
||||
|
||||
|
||||
# Check if the parameter is for optimization.
|
||||
def _is_optimization_param(val):
|
||||
is_opt = val.get('should_optimize', False) # Either does not exist or is false
|
||||
is_opt = val.get('should_optimize', False) # Either does not exist or is false
|
||||
return is_opt
|
||||
|
||||
|
||||
@ -67,7 +69,7 @@ def _get_params_hash(params):
|
||||
|
||||
def optimize(collector, progress=gr.Progress()):
|
||||
# Inform the user that something is happening.
|
||||
progress(0, desc=f'Setting Up...')
|
||||
progress(0, desc='Setting Up...')
|
||||
|
||||
# Track the current step
|
||||
current_step = 0
|
||||
@ -132,4 +134,4 @@ def optimize(collector, progress=gr.Progress()):
|
||||
with open('best_params.json', 'w') as fp:
|
||||
json.dump(_convert_np_types(best_params), fp, indent=4)
|
||||
|
||||
return str_result
|
||||
return str_result
|
||||
|
@ -1,18 +1,16 @@
|
||||
"""
|
||||
This module provides a singleton class `Parameters` that is used to manage all hyperparameters for the embedding application.
|
||||
This module provides a singleton class `Parameters` that is used to manage all hyperparameters for the embedding application.
|
||||
It expects a JSON file in `extensions/superboogav2/config.json`.
|
||||
|
||||
Each element in the JSON must have a `default` value which will be used for the current run. Elements can have `categories`.
|
||||
These categories define the range in which the optimizer will search. If the element is tagged with `"should_optimize": false`,
|
||||
Each element in the JSON must have a `default` value which will be used for the current run. Elements can have `categories`.
|
||||
These categories define the range in which the optimizer will search. If the element is tagged with `"should_optimize": false`,
|
||||
then the optimizer will only ever use the default value.
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import json
|
||||
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
NUM_TO_WORD_METHOD = 'Number to Word'
|
||||
NUM_TO_CHAR_METHOD = 'Number to Char'
|
||||
NUM_TO_CHAR_LONG_METHOD = 'Number to Multi-Char'
|
||||
@ -366,4 +364,4 @@ def set_api_port(value: int):
|
||||
|
||||
|
||||
def set_api_on(value: bool):
|
||||
Parameters.getInstance().hyperparameters['api_on']['default'] = value
|
||||
Parameters.getInstance().hyperparameters['api_on']['default'] = value
|
||||
|
@ -1,5 +1,5 @@
|
||||
beautifulsoup4==4.12.2
|
||||
chromadb==0.3.18
|
||||
chromadb==0.4.24
|
||||
lxml
|
||||
optuna
|
||||
pandas==2.0.3
|
||||
@ -7,4 +7,4 @@ posthog==2.4.2
|
||||
sentence_transformers==2.2.2
|
||||
spacy
|
||||
pytextrank
|
||||
num2words
|
||||
num2words
|
||||
|
@ -7,28 +7,29 @@ from pathlib import Path
|
||||
# Point to where nltk will find the required data.
|
||||
os.environ['NLTK_DATA'] = str(Path("extensions/superboogav2/nltk_data").resolve())
|
||||
|
||||
import textwrap
|
||||
import codecs
|
||||
import textwrap
|
||||
|
||||
import gradio as gr
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
from modules.logging_colors import logger
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
from .utils import create_metadata_source
|
||||
from .chromadb import make_collector
|
||||
from .download_urls import feed_url_into_collector
|
||||
from .data_processor import process_and_add_to_collector
|
||||
from .benchmark import benchmark
|
||||
from .optimize import optimize
|
||||
from .notebook_handler import input_modifier_internal
|
||||
from .chat_handler import custom_generate_chat_prompt_internal
|
||||
from .api import APIManager
|
||||
from .benchmark import benchmark
|
||||
from .chat_handler import custom_generate_chat_prompt_internal
|
||||
from .chromadb import make_collector
|
||||
from .data_processor import process_and_add_to_collector
|
||||
from .download_urls import feed_url_into_collector
|
||||
from .notebook_handler import input_modifier_internal
|
||||
from .optimize import optimize
|
||||
from .utils import create_metadata_source
|
||||
|
||||
collector = None
|
||||
api_manager = None
|
||||
|
||||
|
||||
def setup():
|
||||
global collector
|
||||
global api_manager
|
||||
@ -38,6 +39,7 @@ def setup():
|
||||
if parameters.get_api_on():
|
||||
api_manager.start_server(parameters.get_api_port())
|
||||
|
||||
|
||||
def _feed_data_into_collector(corpus):
|
||||
yield '### Processing data...'
|
||||
process_and_add_to_collector(corpus, collector, False, create_metadata_source('direct-text'))
|
||||
@ -87,7 +89,7 @@ def _get_optimizable_settings() -> list:
|
||||
preprocess_pipeline.append('Merge Spaces')
|
||||
if parameters.should_strip():
|
||||
preprocess_pipeline.append('Strip Edges')
|
||||
|
||||
|
||||
return [
|
||||
parameters.get_time_power(),
|
||||
parameters.get_time_steepness(),
|
||||
@ -104,8 +106,8 @@ def _get_optimizable_settings() -> list:
|
||||
]
|
||||
|
||||
|
||||
def _apply_settings(optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||
preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
|
||||
def _apply_settings(optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||
preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
|
||||
chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup):
|
||||
logger.debug('Applying settings.')
|
||||
|
||||
@ -240,7 +242,7 @@ def ui():
|
||||
with gr.Tab("File input"):
|
||||
file_input = gr.File(label='Input file', type='binary')
|
||||
update_file = gr.Button('Load data')
|
||||
|
||||
|
||||
with gr.Tab("Settings"):
|
||||
with gr.Accordion("Processing settings", open=True):
|
||||
chunk_len = gr.Textbox(value=parameters.get_chunk_len(), label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
|
||||
@ -305,19 +307,16 @@ def ui():
|
||||
optimize_button = gr.Button('Optimize')
|
||||
optimization_steps = gr.Number(value=parameters.get_optimization_steps(), label='Optimization Steps', info='For how many steps to optimize.', interactive=True)
|
||||
|
||||
|
||||
clear_button = gr.Button('❌ Clear Data')
|
||||
|
||||
|
||||
with gr.Column():
|
||||
last_updated = gr.Markdown()
|
||||
|
||||
all_params = [optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||
preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
|
||||
all_params = [optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||
preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
|
||||
chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup]
|
||||
optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||
preprocess_pipeline, chunk_count, context_len, chunk_len]
|
||||
|
||||
optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||
preprocess_pipeline, chunk_count, context_len, chunk_len]
|
||||
|
||||
update_data.click(_feed_data_into_collector, [data_input], last_updated, show_progress=False)
|
||||
update_url.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)
|
||||
@ -326,7 +325,6 @@ def ui():
|
||||
optimize_button.click(_begin_optimization, [], [last_updated] + optimizable_params, show_progress=True)
|
||||
clear_button.click(_clear_data, [], last_updated, show_progress=False)
|
||||
|
||||
|
||||
optimization_steps.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
time_power.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
time_steepness.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
@ -352,4 +350,4 @@ def ui():
|
||||
chunk_regex.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
chunk_len.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
threads.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
strong_cleanup.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
strong_cleanup.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
|
@ -4,6 +4,7 @@ This module contains common functions across multiple other modules.
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
|
||||
# Create the context using the prefix + data_separator + postfix from parameters.
|
||||
def create_context_text(results):
|
||||
context = parameters.get_prefix() + parameters.get_data_separator().join(results) + parameters.get_postfix()
|
||||
@ -13,4 +14,4 @@ def create_context_text(results):
|
||||
|
||||
# Create metadata with the specified source
|
||||
def create_metadata_source(source: str):
|
||||
return {'source': source}
|
||||
return {'source': source}
|
||||
|
Loading…
Reference in New Issue
Block a user