mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 00:46:53 +01:00
Make superbooga & superboogav2 functional again (#5656)
This commit is contained in:
parent
bae14c8f13
commit
2681f6f640
@ -1,43 +1,24 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
import posthog
|
import posthog
|
||||||
import torch
|
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from sentence_transformers import SentenceTransformer
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
from modules.logging_colors import logger
|
# Intercept calls to posthog
|
||||||
|
|
||||||
logger.info('Intercepting all calls to posthog :)')
|
|
||||||
posthog.capture = lambda *args, **kwargs: None
|
posthog.capture = lambda *args, **kwargs: None
|
||||||
|
|
||||||
|
|
||||||
class Collecter():
|
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaCollector():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
name = ''.join(random.choice('ab') for _ in range(10))
|
||||||
|
|
||||||
def add(self, texts: list[str]):
|
self.name = name
|
||||||
pass
|
|
||||||
|
|
||||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Embedder():
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def embed(self, text: str) -> list[torch.Tensor]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ChromaCollector(Collecter):
|
|
||||||
def __init__(self, embedder: Embedder):
|
|
||||||
super().__init__()
|
|
||||||
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
||||||
self.embedder = embedder
|
self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
|
||||||
self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
|
|
||||||
self.ids = []
|
self.ids = []
|
||||||
|
|
||||||
def add(self, texts: list[str]):
|
def add(self, texts: list[str]):
|
||||||
@ -102,24 +83,15 @@ class ChromaCollector(Collecter):
|
|||||||
return sorted(ids)
|
return sorted(ids)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.collection.delete(ids=self.ids)
|
|
||||||
self.ids = []
|
self.ids = []
|
||||||
|
self.chroma_client.delete_collection(name=self.name)
|
||||||
|
self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
|
||||||
class SentenceTransformerEmbedder(Embedder):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
|
|
||||||
self.embed = self.model.encode
|
|
||||||
|
|
||||||
|
|
||||||
def make_collector():
|
def make_collector():
|
||||||
global embedder
|
return ChromaCollector()
|
||||||
return ChromaCollector(embedder)
|
|
||||||
|
|
||||||
|
|
||||||
def add_chunks_to_collector(chunks, collector):
|
def add_chunks_to_collector(chunks, collector):
|
||||||
collector.clear()
|
collector.clear()
|
||||||
collector.add(chunks)
|
collector.add(chunks)
|
||||||
|
|
||||||
|
|
||||||
embedder = SentenceTransformerEmbedder()
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
beautifulsoup4==4.12.2
|
beautifulsoup4==4.12.2
|
||||||
chromadb==0.3.18
|
chromadb==0.4.24
|
||||||
pandas==2.0.3
|
pandas==2.0.3
|
||||||
posthog==2.4.2
|
posthog==2.4.2
|
||||||
sentence_transformers==2.2.2
|
sentence_transformers==2.2.2
|
||||||
|
@ -12,17 +12,16 @@ This module is responsible for the VectorDB API. It currently supports:
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from urllib.parse import urlparse, parse_qs
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
import extensions.superboogav2.parameters as parameters
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
from .chromadb import ChromaCollector
|
from .chromadb import ChromaCollector
|
||||||
from .data_processor import process_and_add_to_collector
|
from .data_processor import process_and_add_to_collector
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
|
||||||
|
|
||||||
|
|
||||||
class CustomThreadingHTTPServer(ThreadingHTTPServer):
|
class CustomThreadingHTTPServer(ThreadingHTTPServer):
|
||||||
def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
|
def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
|
||||||
@ -38,7 +37,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
self.collector = collector
|
self.collector = collector
|
||||||
super().__init__(request, client_address, server)
|
super().__init__(request, client_address, server)
|
||||||
|
|
||||||
|
|
||||||
def _send_412_error(self, message):
|
def _send_412_error(self, message):
|
||||||
self.send_response(412)
|
self.send_response(412)
|
||||||
self.send_header("Content-type", "application/json")
|
self.send_header("Content-type", "application/json")
|
||||||
@ -46,7 +44,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
response = json.dumps({"error": message})
|
response = json.dumps({"error": message})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def _send_404_error(self):
|
def _send_404_error(self):
|
||||||
self.send_response(404)
|
self.send_response(404)
|
||||||
self.send_header("Content-type", "application/json")
|
self.send_header("Content-type", "application/json")
|
||||||
@ -54,7 +51,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
response = json.dumps({"error": "Resource not found"})
|
response = json.dumps({"error": "Resource not found"})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def _send_400_error(self, error_message: str):
|
def _send_400_error(self, error_message: str):
|
||||||
self.send_response(400)
|
self.send_response(400)
|
||||||
self.send_header("Content-type", "application/json")
|
self.send_header("Content-type", "application/json")
|
||||||
@ -62,7 +58,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
response = json.dumps({"error": error_message})
|
response = json.dumps({"error": error_message})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def _send_200_response(self, message: str):
|
def _send_200_response(self, message: str):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header("Content-type", "application/json")
|
self.send_header("Content-type", "application/json")
|
||||||
@ -75,24 +70,21 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
|
def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
|
||||||
if sort_param == parameters.SORT_DISTANCE:
|
if sort_param == parameters.SORT_DISTANCE:
|
||||||
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
||||||
elif sort_param == parameters.SORT_ID:
|
elif sort_param == parameters.SORT_ID:
|
||||||
results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count)
|
results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count)
|
||||||
else: # Default is dist
|
else: # Default is dist
|
||||||
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"results": results
|
"results": results
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
self._send_404_error()
|
self._send_404_error()
|
||||||
|
|
||||||
|
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
try:
|
try:
|
||||||
content_length = int(self.headers['Content-Length'])
|
content_length = int(self.headers['Content-Length'])
|
||||||
@ -146,7 +138,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._send_400_error(str(e))
|
self._send_400_error(str(e))
|
||||||
|
|
||||||
|
|
||||||
def do_DELETE(self):
|
def do_DELETE(self):
|
||||||
try:
|
try:
|
||||||
parsed_path = urlparse(self.path)
|
parsed_path = urlparse(self.path)
|
||||||
@ -161,12 +152,10 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._send_400_error(str(e))
|
self._send_400_error(str(e))
|
||||||
|
|
||||||
|
|
||||||
def do_OPTIONS(self):
|
def do_OPTIONS(self):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
|
|
||||||
def end_headers(self):
|
def end_headers(self):
|
||||||
self.send_header('Access-Control-Allow-Origin', '*')
|
self.send_header('Access-Control-Allow-Origin', '*')
|
||||||
self.send_header('Access-Control-Allow-Methods', '*')
|
self.send_header('Access-Control-Allow-Methods', '*')
|
||||||
@ -197,7 +186,7 @@ class APIManager:
|
|||||||
|
|
||||||
def stop_server(self):
|
def stop_server(self):
|
||||||
if self.server is not None:
|
if self.server is not None:
|
||||||
logger.info(f'Stopping chromaDB API.')
|
logger.info('Stopping chromaDB API.')
|
||||||
self.server.shutdown()
|
self.server.shutdown()
|
||||||
self.server.server_close()
|
self.server.server_close()
|
||||||
self.server = None
|
self.server = None
|
||||||
|
@ -9,13 +9,13 @@ The benchmark function will return the score as an integer.
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .data_processor import process_and_add_to_collector, preprocess_text
|
from .data_processor import preprocess_text, process_and_add_to_collector
|
||||||
from .parameters import get_chunk_count, get_max_token_count
|
from .parameters import get_chunk_count, get_max_token_count
|
||||||
from .utils import create_metadata_source
|
from .utils import create_metadata_source
|
||||||
|
|
||||||
|
|
||||||
def benchmark(config_path, collector):
|
def benchmark(config_path, collector):
|
||||||
# Get the current system date
|
# Get the current system date
|
||||||
sysdate = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
sysdate = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
@ -4,16 +4,17 @@ This module is responsible for modifying the chat prompt and history.
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
from extensions.superboogav2.utils import (
|
||||||
|
create_context_text,
|
||||||
|
create_metadata_source
|
||||||
|
)
|
||||||
from modules import chat, shared
|
from modules import chat, shared
|
||||||
from modules.text_generation import get_encoded_length
|
|
||||||
from modules.logging_colors import logger
|
|
||||||
from modules.chat import load_character_memoized
|
from modules.chat import load_character_memoized
|
||||||
from extensions.superboogav2.utils import create_context_text, create_metadata_source
|
from modules.logging_colors import logger
|
||||||
|
from modules.text_generation import get_encoded_length
|
||||||
|
|
||||||
from .data_processor import process_and_add_to_collector
|
|
||||||
from .chromadb import ChromaCollector
|
from .chromadb import ChromaCollector
|
||||||
|
from .data_processor import process_and_add_to_collector
|
||||||
|
|
||||||
CHAT_METADATA = create_metadata_source('automatic-chat-insert')
|
CHAT_METADATA = create_metadata_source('automatic-chat-insert')
|
||||||
|
|
||||||
@ -69,7 +70,7 @@ def _concatinate_history(history: dict, state: dict):
|
|||||||
if len(exchange) >= 2:
|
if len(exchange) >= 2:
|
||||||
full_history_text += _format_single_exchange(bot_name, exchange[1])
|
full_history_text += _format_single_exchange(bot_name, exchange[1])
|
||||||
|
|
||||||
return full_history_text[:-1] # Remove the last new line.
|
return full_history_text[:-1] # Remove the last new line.
|
||||||
|
|
||||||
|
|
||||||
def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
|
def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
|
||||||
@ -95,7 +96,7 @@ def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
|
|||||||
else:
|
else:
|
||||||
# replace the message at replace_position with context_text
|
# replace the message at replace_position with context_text
|
||||||
i, j = replace_position
|
i, j = replace_position
|
||||||
history['internal'][-i-1][-j-1] = context_text
|
history['internal'][-i - 1][-j - 1] = context_text
|
||||||
|
|
||||||
|
|
||||||
def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):
|
def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):
|
||||||
|
@ -1,42 +1,23 @@
|
|||||||
import threading
|
|
||||||
import chromadb
|
|
||||||
import posthog
|
|
||||||
import torch
|
|
||||||
import math
|
import math
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import chromadb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import extensions.superboogav2.parameters as parameters
|
import posthog
|
||||||
|
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from sentence_transformers import SentenceTransformer
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
|
import extensions.superboogav2.parameters as parameters
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.text_generation import encode, decode
|
from modules.text_generation import decode, encode
|
||||||
|
|
||||||
logger.debug('Intercepting all calls to posthog.')
|
# Intercept calls to posthog
|
||||||
posthog.capture = lambda *args, **kwargs: None
|
posthog.capture = lambda *args, **kwargs: None
|
||||||
|
|
||||||
|
|
||||||
class Collecter():
|
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Embedder():
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def embed(self, text: str) -> list[torch.Tensor]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Info:
|
class Info:
|
||||||
def __init__(self, start_index, text_with_context, distance, id):
|
def __init__(self, start_index, text_with_context, distance, id):
|
||||||
@ -58,7 +39,7 @@ class Info:
|
|||||||
elif parameters.get_new_dist_strategy() == parameters.DIST_ARITHMETIC_STRATEGY:
|
elif parameters.get_new_dist_strategy() == parameters.DIST_ARITHMETIC_STRATEGY:
|
||||||
# Arithmetic mean
|
# Arithmetic mean
|
||||||
return (self.distance + other_info.distance) / 2
|
return (self.distance + other_info.distance) / 2
|
||||||
else: # Min is default
|
else: # Min is default
|
||||||
return min(self.distance, other_info.distance)
|
return min(self.distance, other_info.distance)
|
||||||
|
|
||||||
def merge_with(self, other_info):
|
def merge_with(self, other_info):
|
||||||
@ -93,16 +74,19 @@ class Info:
|
|||||||
|
|
||||||
return not (s1_end < s2_start or s2_end < s1_start)
|
return not (s1_end < s2_start or s2_end < s1_start)
|
||||||
|
|
||||||
class ChromaCollector(Collecter):
|
|
||||||
def __init__(self, embedder: Embedder):
|
class ChromaCollector():
|
||||||
super().__init__()
|
def __init__(self):
|
||||||
|
name = ''.join(random.choice('ab') for _ in range(10))
|
||||||
|
|
||||||
|
self.name = name
|
||||||
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
||||||
self.embedder = embedder
|
self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
|
||||||
self.collection = self.chroma_client.create_collection(name="context", embedding_function=self.embedder.embed)
|
|
||||||
self.ids = []
|
self.ids = []
|
||||||
self.id_to_info = {}
|
self.id_to_info = {}
|
||||||
self.embeddings_cache = {}
|
self.embeddings_cache = {}
|
||||||
self.lock = threading.Lock() # Locking so the server doesn't break.
|
self.lock = threading.Lock() # Locking so the server doesn't break.
|
||||||
|
|
||||||
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None):
|
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
@ -114,7 +98,7 @@ class ChromaCollector(Collecter):
|
|||||||
new_ids = self._get_new_ids(len(texts))
|
new_ids = self._get_new_ids(len(texts))
|
||||||
|
|
||||||
(existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
(existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
||||||
(non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas)
|
(non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas)
|
||||||
|
|
||||||
# If there are any already existing texts, add them all at once.
|
# If there are any already existing texts, add them all at once.
|
||||||
if existing_texts:
|
if existing_texts:
|
||||||
@ -126,7 +110,7 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
# If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
|
# If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
|
||||||
if non_existing_texts:
|
if non_existing_texts:
|
||||||
non_existing_embeddings = self.embedder.embed(non_existing_texts).tolist()
|
non_existing_embeddings = embedder(non_existing_texts)
|
||||||
for text, embedding in zip(non_existing_texts, non_existing_embeddings):
|
for text, embedding in zip(non_existing_texts, non_existing_embeddings):
|
||||||
self.embeddings_cache[text] = embedding
|
self.embeddings_cache[text] = embedding
|
||||||
|
|
||||||
@ -145,7 +129,6 @@ class ChromaCollector(Collecter):
|
|||||||
self.id_to_info.update(new_info)
|
self.id_to_info.update(new_info)
|
||||||
self.ids.extend(new_ids)
|
self.ids.extend(new_ids)
|
||||||
|
|
||||||
|
|
||||||
def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]):
|
def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]):
|
||||||
existing_texts, non_existing_texts = [], []
|
existing_texts, non_existing_texts = [], []
|
||||||
existing_embeddings = []
|
existing_embeddings = []
|
||||||
@ -169,7 +152,6 @@ class ChromaCollector(Collecter):
|
|||||||
return (existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
return (existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
||||||
(non_existing_texts, non_existing_ids, non_existing_metas)
|
(non_existing_texts, non_existing_ids, non_existing_metas)
|
||||||
|
|
||||||
|
|
||||||
def _get_new_ids(self, num_new_ids: int):
|
def _get_new_ids(self, num_new_ids: int):
|
||||||
if self.ids:
|
if self.ids:
|
||||||
max_existing_id = max(int(id_) for id_ in self.ids)
|
max_existing_id = max(int(id_) for id_ in self.ids)
|
||||||
@ -178,7 +160,6 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
return [str(i + max_existing_id + 1) for i in range(num_new_ids)]
|
return [str(i + max_existing_id + 1) for i in range(num_new_ids)]
|
||||||
|
|
||||||
|
|
||||||
def _find_min_max_start_index(self):
|
def _find_min_max_start_index(self):
|
||||||
max_index, min_index = 0, float('inf')
|
max_index, min_index = 0, float('inf')
|
||||||
for _, val in self.id_to_info.items():
|
for _, val in self.id_to_info.items():
|
||||||
@ -188,13 +169,14 @@ class ChromaCollector(Collecter):
|
|||||||
min_index = val['start_index']
|
min_index = val['start_index']
|
||||||
return min_index, max_index
|
return min_index, max_index
|
||||||
|
|
||||||
|
|
||||||
# NB: Does not make sense to weigh excerpts from different documents.
|
# NB: Does not make sense to weigh excerpts from different documents.
|
||||||
# But let's say that's the user's problem. Perfect world scenario:
|
# But let's say that's the user's problem. Perfect world scenario:
|
||||||
# Apply time weighing to different documents. For each document, then, add
|
# Apply time weighing to different documents. For each document, then, add
|
||||||
# separate time weighing.
|
# separate time weighing.
|
||||||
|
|
||||||
def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float):
|
def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float):
|
||||||
sigmoid = lambda x: 1 / (1 + np.exp(-x))
|
def sigmoid(x):
|
||||||
|
return 1 / (1 + np.exp(-x))
|
||||||
|
|
||||||
weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))
|
weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))
|
||||||
|
|
||||||
@ -210,7 +192,6 @@ class ChromaCollector(Collecter):
|
|||||||
index = info.start_index
|
index = info.start_index
|
||||||
info.distance *= weights[index]
|
info.distance *= weights[index]
|
||||||
|
|
||||||
|
|
||||||
def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float):
|
def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float):
|
||||||
# Ensure there are infos to filter
|
# Ensure there are infos to filter
|
||||||
if not infos:
|
if not infos:
|
||||||
@ -231,7 +212,6 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
return filtered_infos
|
return filtered_infos
|
||||||
|
|
||||||
|
|
||||||
def _merge_infos(self, infos: list[Info]):
|
def _merge_infos(self, infos: list[Info]):
|
||||||
merged_infos = []
|
merged_infos = []
|
||||||
current_info = infos[0]
|
current_info = infos[0]
|
||||||
@ -247,8 +227,8 @@ class ChromaCollector(Collecter):
|
|||||||
merged_infos.append(current_info)
|
merged_infos.append(current_info)
|
||||||
return merged_infos
|
return merged_infos
|
||||||
|
|
||||||
|
|
||||||
# Main function for retrieving chunks by distance. It performs merging, time weighing, and mean filtering.
|
# Main function for retrieving chunks by distance. It performs merging, time weighing, and mean filtering.
|
||||||
|
|
||||||
def _get_documents_ids_distances(self, search_strings: list[str], n_results: int):
|
def _get_documents_ids_distances(self, search_strings: list[str], n_results: int):
|
||||||
n_results = min(len(self.ids), n_results)
|
n_results = min(len(self.ids), n_results)
|
||||||
if n_results == 0:
|
if n_results == 0:
|
||||||
@ -280,22 +260,22 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
return texts_with_context, ids, distances
|
return texts_with_context, ids, distances
|
||||||
|
|
||||||
|
|
||||||
# Get chunks by similarity
|
# Get chunks by similarity
|
||||||
|
|
||||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
documents, _, _ = self._get_documents_ids_distances(search_strings, n_results)
|
documents, _, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
|
||||||
# Get ids by similarity
|
# Get ids by similarity
|
||||||
|
|
||||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
_, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
_, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
|
||||||
# Cutoff token count
|
# Cutoff token count
|
||||||
|
|
||||||
def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int):
|
def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int):
|
||||||
# TODO: Move to caller; We add delimiters there which might go over the limit.
|
# TODO: Move to caller; We add delimiters there which might go over the limit.
|
||||||
current_token_count = 0
|
current_token_count = 0
|
||||||
@ -318,8 +298,8 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
return return_documents
|
return return_documents
|
||||||
|
|
||||||
|
|
||||||
# Get chunks by similarity and then sort by ids
|
# Get chunks by similarity and then sort by ids
|
||||||
|
|
||||||
def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||||
@ -327,20 +307,19 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
return self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
return self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
||||||
|
|
||||||
|
|
||||||
# Get chunks by similarity and then sort by distance (lowest distance is last).
|
# Get chunks by similarity and then sort by distance (lowest distance is last).
|
||||||
|
|
||||||
def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
documents, _, distances = self._get_documents_ids_distances(search_strings, n_results)
|
documents, _, distances = self._get_documents_ids_distances(search_strings, n_results)
|
||||||
sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])] # sorted lowest -> highest
|
sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])] # sorted lowest -> highest
|
||||||
|
|
||||||
# If a document is truncated or competely skipped, it would be with high distance.
|
# If a document is truncated or competely skipped, it would be with high distance.
|
||||||
return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
||||||
return_documents.reverse() # highest -> lowest
|
return_documents.reverse() # highest -> lowest
|
||||||
|
|
||||||
return return_documents
|
return return_documents
|
||||||
|
|
||||||
|
|
||||||
def delete(self, ids_to_delete: list[str], where: dict):
|
def delete(self, ids_to_delete: list[str], where: dict):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
ids_to_delete = self.collection.get(ids=ids_to_delete, where=where)['ids']
|
ids_to_delete = self.collection.get(ids=ids_to_delete, where=where)['ids']
|
||||||
@ -354,23 +333,16 @@ class ChromaCollector(Collecter):
|
|||||||
|
|
||||||
logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.')
|
logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.')
|
||||||
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.chroma_client.reset()
|
self.chroma_client.reset()
|
||||||
self.collection = self.chroma_client.create_collection("context", embedding_function=self.embedder.embed)
|
|
||||||
self.ids = []
|
self.ids = []
|
||||||
self.id_to_info = {}
|
self.chroma_client.delete_collection(name=self.name)
|
||||||
|
self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
|
||||||
|
|
||||||
logger.info('Successfully cleared all records and reset chromaDB.')
|
logger.info('Successfully cleared all records and reset chromaDB.')
|
||||||
|
|
||||||
|
|
||||||
class SentenceTransformerEmbedder(Embedder):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
logger.debug('Creating Sentence Embedder...')
|
|
||||||
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
|
|
||||||
self.embed = self.model.encode
|
|
||||||
|
|
||||||
|
|
||||||
def make_collector():
|
def make_collector():
|
||||||
return ChromaCollector(SentenceTransformerEmbedder())
|
return ChromaCollector()
|
||||||
|
@ -11,20 +11,19 @@ This module contains utils for preprocessing the text before converting it to em
|
|||||||
* removing specific parts of speech (adverbs and interjections)
|
* removing specific parts of speech (adverbs and interjections)
|
||||||
- TextSummarizer extracts the most important sentences from a long string using text-ranking.
|
- TextSummarizer extracts the most important sentences from a long string using text-ranking.
|
||||||
"""
|
"""
|
||||||
import pytextrank
|
|
||||||
import string
|
|
||||||
import spacy
|
|
||||||
import math
|
import math
|
||||||
import nltk
|
|
||||||
import re
|
import re
|
||||||
|
import string
|
||||||
|
|
||||||
|
import nltk
|
||||||
|
import spacy
|
||||||
from nltk.corpus import stopwords
|
from nltk.corpus import stopwords
|
||||||
from nltk.stem import WordNetLemmatizer
|
from nltk.stem import WordNetLemmatizer
|
||||||
from num2words import num2words
|
from num2words import num2words
|
||||||
|
|
||||||
|
|
||||||
class TextPreprocessorBuilder:
|
class TextPreprocessorBuilder:
|
||||||
# Define class variables as None initially
|
# Define class variables as None initially
|
||||||
_stop_words = set(stopwords.words('english'))
|
_stop_words = set(stopwords.words('english'))
|
||||||
_lemmatizer = WordNetLemmatizer()
|
_lemmatizer = WordNetLemmatizer()
|
||||||
|
|
||||||
@ -32,11 +31,9 @@ class TextPreprocessorBuilder:
|
|||||||
_lemmatizer_cache = {}
|
_lemmatizer_cache = {}
|
||||||
_pos_remove_cache = {}
|
_pos_remove_cache = {}
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, text: str):
|
def __init__(self, text: str):
|
||||||
self.text = text
|
self.text = text
|
||||||
|
|
||||||
|
|
||||||
def to_lower(self):
|
def to_lower(self):
|
||||||
# Match both words and non-word characters
|
# Match both words and non-word characters
|
||||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||||
@ -49,7 +46,6 @@ class TextPreprocessorBuilder:
|
|||||||
self.text = "".join(tokens)
|
self.text = "".join(tokens)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def num_to_word(self, min_len: int = 1):
|
def num_to_word(self, min_len: int = 1):
|
||||||
# Match both words and non-word characters
|
# Match both words and non-word characters
|
||||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||||
@ -58,11 +54,10 @@ class TextPreprocessorBuilder:
|
|||||||
if token.isdigit() and len(token) >= min_len:
|
if token.isdigit() and len(token) >= min_len:
|
||||||
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
||||||
# 740700 will become "seven hundred and forty thousand seven hundred".
|
# 740700 will become "seven hundred and forty thousand seven hundred".
|
||||||
tokens[i] = num2words(int(token)).replace(",","") # Remove commas from num2words.
|
tokens[i] = num2words(int(token)).replace(",", "") # Remove commas from num2words.
|
||||||
self.text = "".join(tokens)
|
self.text = "".join(tokens)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def num_to_char_long(self, min_len: int = 1):
|
def num_to_char_long(self, min_len: int = 1):
|
||||||
# Match both words and non-word characters
|
# Match both words and non-word characters
|
||||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||||
@ -71,7 +66,9 @@ class TextPreprocessorBuilder:
|
|||||||
if token.isdigit() and len(token) >= min_len:
|
if token.isdigit() and len(token) >= min_len:
|
||||||
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
||||||
# 740700 will become HHHHHHEEEEEAAAAHHHAAA
|
# 740700 will become HHHHHHEEEEEAAAAHHHAAA
|
||||||
convert_token = lambda token: ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]
|
def convert_token(token):
|
||||||
|
return ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]
|
||||||
|
|
||||||
tokens[i] = convert_token(tokens[i])
|
tokens[i] = convert_token(tokens[i])
|
||||||
self.text = "".join(tokens)
|
self.text = "".join(tokens)
|
||||||
return self
|
return self
|
||||||
@ -150,6 +147,7 @@ class TextPreprocessorBuilder:
|
|||||||
def build(self):
|
def build(self):
|
||||||
return self.text
|
return self.text
|
||||||
|
|
||||||
|
|
||||||
class TextSummarizer:
|
class TextSummarizer:
|
||||||
_nlp_pipeline = None
|
_nlp_pipeline = None
|
||||||
_cache = {}
|
_cache = {}
|
||||||
|
@ -4,13 +4,14 @@ It will then split it into chunks of specified length. For each of those chunks,
|
|||||||
It will only include full words.
|
It will only include full words.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
|
||||||
import bisect
|
import bisect
|
||||||
|
import re
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
|
||||||
from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer
|
|
||||||
from .chromadb import ChromaCollector
|
from .chromadb import ChromaCollector
|
||||||
|
from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer
|
||||||
|
|
||||||
|
|
||||||
def preprocess_text_no_summary(text) -> str:
|
def preprocess_text_no_summary(text) -> str:
|
||||||
builder = TextPreprocessorBuilder(text)
|
builder = TextPreprocessorBuilder(text)
|
||||||
@ -125,7 +126,7 @@ def _clear_chunks(data_chunks, data_chunks_with_context, data_chunk_starting_ind
|
|||||||
seen_chunk_start = seen_chunks.get(chunk)
|
seen_chunk_start = seen_chunks.get(chunk)
|
||||||
if seen_chunk_start:
|
if seen_chunk_start:
|
||||||
# If we've already seen this exact chunk, and the context around it it very close to the seen chunk, then skip it.
|
# If we've already seen this exact chunk, and the context around it it very close to the seen chunk, then skip it.
|
||||||
if abs(seen_chunk_start-index) < parameters.get_delta_start():
|
if abs(seen_chunk_start - index) < parameters.get_delta_start():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
distinct_data_chunks.append(chunk)
|
distinct_data_chunks.append(chunk)
|
||||||
@ -206,4 +207,4 @@ def process_and_add_to_collector(corpus: str, collector: ChromaCollector, clear_
|
|||||||
|
|
||||||
if clear_collector_before_adding:
|
if clear_collector_before_adding:
|
||||||
collector.clear()
|
collector.clear()
|
||||||
collector.add(data_chunks, data_chunks_with_context, data_chunk_starting_indices, [metadata]*len(data_chunks) if metadata is not None else None)
|
collector.add(data_chunks, data_chunks_with_context, data_chunk_starting_indices, [metadata] * len(data_chunks) if metadata is not None else None)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import requests
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import requests
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
@ -9,6 +9,7 @@ import extensions.superboogav2.parameters as parameters
|
|||||||
from .data_processor import process_and_add_to_collector
|
from .data_processor import process_and_add_to_collector
|
||||||
from .utils import create_metadata_source
|
from .utils import create_metadata_source
|
||||||
|
|
||||||
|
|
||||||
def _download_single(url):
|
def _download_single(url):
|
||||||
response = requests.get(url, timeout=5)
|
response = requests.get(url, timeout=5)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
|
@ -4,13 +4,12 @@ This module is responsible for handling and modifying the notebook text.
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
|
||||||
from modules import shared
|
|
||||||
from modules.logging_colors import logger
|
|
||||||
from extensions.superboogav2.utils import create_context_text
|
from extensions.superboogav2.utils import create_context_text
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
from .data_processor import preprocess_text
|
from .data_processor import preprocess_text
|
||||||
|
|
||||||
|
|
||||||
def _remove_special_tokens(string):
|
def _remove_special_tokens(string):
|
||||||
pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
|
pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
|
||||||
return re.sub(pattern, '', string)
|
return re.sub(pattern, '', string)
|
||||||
|
@ -3,22 +3,24 @@ This module implements a hyperparameter optimization routine for the embedding a
|
|||||||
|
|
||||||
Each run, the optimizer will set the default values inside the hyperparameters. At the end, it will output the best ones it has found.
|
Each run, the optimizer will set the default values inside the hyperparameters. At the end, it will output the best ones it has found.
|
||||||
"""
|
"""
|
||||||
import re
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import optuna
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import logging
|
import optuna
|
||||||
import hashlib
|
|
||||||
logging.getLogger('optuna').setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
logging.getLogger('optuna').setLevel(logging.WARNING)
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
from .benchmark import benchmark
|
from .benchmark import benchmark
|
||||||
from .parameters import Parameters
|
from .parameters import Parameters
|
||||||
from modules.logging_colors import logger
|
|
||||||
|
|
||||||
|
|
||||||
# Format the parameters into markdown format.
|
# Format the parameters into markdown format.
|
||||||
@ -55,7 +57,7 @@ def _set_hyperparameters(params):
|
|||||||
|
|
||||||
# Check if the parameter is for optimization.
|
# Check if the parameter is for optimization.
|
||||||
def _is_optimization_param(val):
|
def _is_optimization_param(val):
|
||||||
is_opt = val.get('should_optimize', False) # Either does not exist or is false
|
is_opt = val.get('should_optimize', False) # Either does not exist or is false
|
||||||
return is_opt
|
return is_opt
|
||||||
|
|
||||||
|
|
||||||
@ -67,7 +69,7 @@ def _get_params_hash(params):
|
|||||||
|
|
||||||
def optimize(collector, progress=gr.Progress()):
|
def optimize(collector, progress=gr.Progress()):
|
||||||
# Inform the user that something is happening.
|
# Inform the user that something is happening.
|
||||||
progress(0, desc=f'Setting Up...')
|
progress(0, desc='Setting Up...')
|
||||||
|
|
||||||
# Track the current step
|
# Track the current step
|
||||||
current_step = 0
|
current_step = 0
|
||||||
|
@ -6,13 +6,11 @@ Each element in the JSON must have a `default` value which will be used for the
|
|||||||
These categories define the range in which the optimizer will search. If the element is tagged with `"should_optimize": false`,
|
These categories define the range in which the optimizer will search. If the element is tagged with `"should_optimize": false`,
|
||||||
then the optimizer will only ever use the default value.
|
then the optimizer will only ever use the default value.
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
NUM_TO_WORD_METHOD = 'Number to Word'
|
NUM_TO_WORD_METHOD = 'Number to Word'
|
||||||
NUM_TO_CHAR_METHOD = 'Number to Char'
|
NUM_TO_CHAR_METHOD = 'Number to Char'
|
||||||
NUM_TO_CHAR_LONG_METHOD = 'Number to Multi-Char'
|
NUM_TO_CHAR_LONG_METHOD = 'Number to Multi-Char'
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
beautifulsoup4==4.12.2
|
beautifulsoup4==4.12.2
|
||||||
chromadb==0.3.18
|
chromadb==0.4.24
|
||||||
lxml
|
lxml
|
||||||
optuna
|
optuna
|
||||||
pandas==2.0.3
|
pandas==2.0.3
|
||||||
|
@ -7,28 +7,29 @@ from pathlib import Path
|
|||||||
# Point to where nltk will find the required data.
|
# Point to where nltk will find the required data.
|
||||||
os.environ['NLTK_DATA'] = str(Path("extensions/superboogav2/nltk_data").resolve())
|
os.environ['NLTK_DATA'] = str(Path("extensions/superboogav2/nltk_data").resolve())
|
||||||
|
|
||||||
import textwrap
|
|
||||||
import codecs
|
import codecs
|
||||||
|
import textwrap
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
|
||||||
from modules.logging_colors import logger
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
from .utils import create_metadata_source
|
|
||||||
from .chromadb import make_collector
|
|
||||||
from .download_urls import feed_url_into_collector
|
|
||||||
from .data_processor import process_and_add_to_collector
|
|
||||||
from .benchmark import benchmark
|
|
||||||
from .optimize import optimize
|
|
||||||
from .notebook_handler import input_modifier_internal
|
|
||||||
from .chat_handler import custom_generate_chat_prompt_internal
|
|
||||||
from .api import APIManager
|
from .api import APIManager
|
||||||
|
from .benchmark import benchmark
|
||||||
|
from .chat_handler import custom_generate_chat_prompt_internal
|
||||||
|
from .chromadb import make_collector
|
||||||
|
from .data_processor import process_and_add_to_collector
|
||||||
|
from .download_urls import feed_url_into_collector
|
||||||
|
from .notebook_handler import input_modifier_internal
|
||||||
|
from .optimize import optimize
|
||||||
|
from .utils import create_metadata_source
|
||||||
|
|
||||||
collector = None
|
collector = None
|
||||||
api_manager = None
|
api_manager = None
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
global collector
|
global collector
|
||||||
global api_manager
|
global api_manager
|
||||||
@ -38,6 +39,7 @@ def setup():
|
|||||||
if parameters.get_api_on():
|
if parameters.get_api_on():
|
||||||
api_manager.start_server(parameters.get_api_port())
|
api_manager.start_server(parameters.get_api_port())
|
||||||
|
|
||||||
|
|
||||||
def _feed_data_into_collector(corpus):
|
def _feed_data_into_collector(corpus):
|
||||||
yield '### Processing data...'
|
yield '### Processing data...'
|
||||||
process_and_add_to_collector(corpus, collector, False, create_metadata_source('direct-text'))
|
process_and_add_to_collector(corpus, collector, False, create_metadata_source('direct-text'))
|
||||||
@ -305,10 +307,8 @@ def ui():
|
|||||||
optimize_button = gr.Button('Optimize')
|
optimize_button = gr.Button('Optimize')
|
||||||
optimization_steps = gr.Number(value=parameters.get_optimization_steps(), label='Optimization Steps', info='For how many steps to optimize.', interactive=True)
|
optimization_steps = gr.Number(value=parameters.get_optimization_steps(), label='Optimization Steps', info='For how many steps to optimize.', interactive=True)
|
||||||
|
|
||||||
|
|
||||||
clear_button = gr.Button('❌ Clear Data')
|
clear_button = gr.Button('❌ Clear Data')
|
||||||
|
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
last_updated = gr.Markdown()
|
last_updated = gr.Markdown()
|
||||||
|
|
||||||
@ -316,8 +316,7 @@ def ui():
|
|||||||
preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
|
preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
|
||||||
chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup]
|
chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup]
|
||||||
optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||||
preprocess_pipeline, chunk_count, context_len, chunk_len]
|
preprocess_pipeline, chunk_count, context_len, chunk_len]
|
||||||
|
|
||||||
|
|
||||||
update_data.click(_feed_data_into_collector, [data_input], last_updated, show_progress=False)
|
update_data.click(_feed_data_into_collector, [data_input], last_updated, show_progress=False)
|
||||||
update_url.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)
|
update_url.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)
|
||||||
@ -326,7 +325,6 @@ def ui():
|
|||||||
optimize_button.click(_begin_optimization, [], [last_updated] + optimizable_params, show_progress=True)
|
optimize_button.click(_begin_optimization, [], [last_updated] + optimizable_params, show_progress=True)
|
||||||
clear_button.click(_clear_data, [], last_updated, show_progress=False)
|
clear_button.click(_clear_data, [], last_updated, show_progress=False)
|
||||||
|
|
||||||
|
|
||||||
optimization_steps.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
optimization_steps.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||||
time_power.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
time_power.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||||
time_steepness.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
time_steepness.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||||
|
@ -4,6 +4,7 @@ This module contains common functions across multiple other modules.
|
|||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
|
|
||||||
|
|
||||||
# Create the context using the prefix + data_separator + postfix from parameters.
|
# Create the context using the prefix + data_separator + postfix from parameters.
|
||||||
def create_context_text(results):
|
def create_context_text(results):
|
||||||
context = parameters.get_prefix() + parameters.get_data_separator().join(results) + parameters.get_postfix()
|
context = parameters.get_prefix() + parameters.get_data_separator().join(results) + parameters.get_postfix()
|
||||||
|
Loading…
Reference in New Issue
Block a user