mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 00:46:53 +01:00
Make superbooga & superboogav2 functional again (#5656)
This commit is contained in:
parent
bae14c8f13
commit
2681f6f640
@ -1,43 +1,24 @@
|
||||
import random
|
||||
|
||||
import chromadb
|
||||
import posthog
|
||||
import torch
|
||||
from chromadb.config import Settings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from modules.logging_colors import logger
|
||||
|
||||
logger.info('Intercepting all calls to posthog :)')
|
||||
# Intercept calls to posthog
|
||||
posthog.capture = lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class Collecter():
|
||||
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
|
||||
|
||||
|
||||
class ChromaCollector():
|
||||
def __init__(self):
|
||||
pass
|
||||
name = ''.join(random.choice('ab') for _ in range(10))
|
||||
|
||||
def add(self, texts: list[str]):
|
||||
pass
|
||||
|
||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
pass
|
||||
|
||||
def clear(self):
|
||||
pass
|
||||
|
||||
|
||||
class Embedder():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def embed(self, text: str) -> list[torch.Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
class ChromaCollector(Collecter):
|
||||
def __init__(self, embedder: Embedder):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
||||
self.embedder = embedder
|
||||
self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
|
||||
self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
|
||||
self.ids = []
|
||||
|
||||
def add(self, texts: list[str]):
|
||||
@ -102,24 +83,15 @@ class ChromaCollector(Collecter):
|
||||
return sorted(ids)
|
||||
|
||||
def clear(self):
|
||||
self.collection.delete(ids=self.ids)
|
||||
self.ids = []
|
||||
|
||||
|
||||
class SentenceTransformerEmbedder(Embedder):
|
||||
def __init__(self) -> None:
|
||||
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
|
||||
self.embed = self.model.encode
|
||||
self.chroma_client.delete_collection(name=self.name)
|
||||
self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
|
||||
|
||||
|
||||
def make_collector():
|
||||
global embedder
|
||||
return ChromaCollector(embedder)
|
||||
return ChromaCollector()
|
||||
|
||||
|
||||
def add_chunks_to_collector(chunks, collector):
|
||||
collector.clear()
|
||||
collector.add(chunks)
|
||||
|
||||
|
||||
embedder = SentenceTransformerEmbedder()
|
||||
|
@ -1,5 +1,5 @@
|
||||
beautifulsoup4==4.12.2
|
||||
chromadb==0.3.18
|
||||
chromadb==0.4.24
|
||||
pandas==2.0.3
|
||||
posthog==2.4.2
|
||||
sentence_transformers==2.2.2
|
||||
|
@ -12,17 +12,16 @@ This module is responsible for the VectorDB API. It currently supports:
|
||||
|
||||
import json
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
from threading import Thread
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
from .chromadb import ChromaCollector
|
||||
from .data_processor import process_and_add_to_collector
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
|
||||
class CustomThreadingHTTPServer(ThreadingHTTPServer):
|
||||
def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
|
||||
@ -38,7 +37,6 @@ class Handler(BaseHTTPRequestHandler):
|
||||
self.collector = collector
|
||||
super().__init__(request, client_address, server)
|
||||
|
||||
|
||||
def _send_412_error(self, message):
|
||||
self.send_response(412)
|
||||
self.send_header("Content-type", "application/json")
|
||||
@ -46,7 +44,6 @@ class Handler(BaseHTTPRequestHandler):
|
||||
response = json.dumps({"error": message})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
|
||||
def _send_404_error(self):
|
||||
self.send_response(404)
|
||||
self.send_header("Content-type", "application/json")
|
||||
@ -54,7 +51,6 @@ class Handler(BaseHTTPRequestHandler):
|
||||
response = json.dumps({"error": "Resource not found"})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
|
||||
def _send_400_error(self, error_message: str):
|
||||
self.send_response(400)
|
||||
self.send_header("Content-type", "application/json")
|
||||
@ -62,7 +58,6 @@ class Handler(BaseHTTPRequestHandler):
|
||||
response = json.dumps({"error": error_message})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
|
||||
def _send_200_response(self, message: str):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
@ -75,7 +70,6 @@ class Handler(BaseHTTPRequestHandler):
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
|
||||
def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
|
||||
if sort_param == parameters.SORT_DISTANCE:
|
||||
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
|
||||
@ -88,11 +82,9 @@ class Handler(BaseHTTPRequestHandler):
|
||||
"results": results
|
||||
}
|
||||
|
||||
|
||||
def do_GET(self):
|
||||
self._send_404_error()
|
||||
|
||||
|
||||
def do_POST(self):
|
||||
try:
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
@ -146,7 +138,6 @@ class Handler(BaseHTTPRequestHandler):
|
||||
except Exception as e:
|
||||
self._send_400_error(str(e))
|
||||
|
||||
|
||||
def do_DELETE(self):
|
||||
try:
|
||||
parsed_path = urlparse(self.path)
|
||||
@ -161,12 +152,10 @@ class Handler(BaseHTTPRequestHandler):
|
||||
except Exception as e:
|
||||
self._send_400_error(str(e))
|
||||
|
||||
|
||||
def do_OPTIONS(self):
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
|
||||
|
||||
def end_headers(self):
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Methods', '*')
|
||||
@ -197,7 +186,7 @@ class APIManager:
|
||||
|
||||
def stop_server(self):
|
||||
if self.server is not None:
|
||||
logger.info(f'Stopping chromaDB API.')
|
||||
logger.info('Stopping chromaDB API.')
|
||||
self.server.shutdown()
|
||||
self.server.server_close()
|
||||
self.server = None
|
||||
|
@ -9,13 +9,13 @@ The benchmark function will return the score as an integer.
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .data_processor import process_and_add_to_collector, preprocess_text
|
||||
from .data_processor import preprocess_text, process_and_add_to_collector
|
||||
from .parameters import get_chunk_count, get_max_token_count
|
||||
from .utils import create_metadata_source
|
||||
|
||||
|
||||
def benchmark(config_path, collector):
|
||||
# Get the current system date
|
||||
sysdate = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
@ -4,16 +4,17 @@ This module is responsible for modifying the chat prompt and history.
|
||||
import re
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
from extensions.superboogav2.utils import (
|
||||
create_context_text,
|
||||
create_metadata_source
|
||||
)
|
||||
from modules import chat, shared
|
||||
from modules.text_generation import get_encoded_length
|
||||
from modules.logging_colors import logger
|
||||
from modules.chat import load_character_memoized
|
||||
from extensions.superboogav2.utils import create_context_text, create_metadata_source
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import get_encoded_length
|
||||
|
||||
from .data_processor import process_and_add_to_collector
|
||||
from .chromadb import ChromaCollector
|
||||
|
||||
from .data_processor import process_and_add_to_collector
|
||||
|
||||
CHAT_METADATA = create_metadata_source('automatic-chat-insert')
|
||||
|
||||
|
@ -1,42 +1,23 @@
|
||||
import threading
|
||||
import chromadb
|
||||
import posthog
|
||||
import torch
|
||||
import math
|
||||
import random
|
||||
import threading
|
||||
|
||||
import chromadb
|
||||
import numpy as np
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
import posthog
|
||||
from chromadb.config import Settings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import encode, decode
|
||||
from modules.text_generation import decode, encode
|
||||
|
||||
logger.debug('Intercepting all calls to posthog.')
|
||||
# Intercept calls to posthog
|
||||
posthog.capture = lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class Collecter():
|
||||
def __init__(self):
|
||||
pass
|
||||
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
|
||||
|
||||
def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int]):
|
||||
pass
|
||||
|
||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
pass
|
||||
|
||||
def clear(self):
|
||||
pass
|
||||
|
||||
|
||||
class Embedder():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def embed(self, text: str) -> list[torch.Tensor]:
|
||||
pass
|
||||
|
||||
class Info:
|
||||
def __init__(self, start_index, text_with_context, distance, id):
|
||||
@ -93,12 +74,15 @@ class Info:
|
||||
|
||||
return not (s1_end < s2_start or s2_end < s1_start)
|
||||
|
||||
class ChromaCollector(Collecter):
|
||||
def __init__(self, embedder: Embedder):
|
||||
super().__init__()
|
||||
|
||||
class ChromaCollector():
|
||||
def __init__(self):
|
||||
name = ''.join(random.choice('ab') for _ in range(10))
|
||||
|
||||
self.name = name
|
||||
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
||||
self.embedder = embedder
|
||||
self.collection = self.chroma_client.create_collection(name="context", embedding_function=self.embedder.embed)
|
||||
self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
|
||||
|
||||
self.ids = []
|
||||
self.id_to_info = {}
|
||||
self.embeddings_cache = {}
|
||||
@ -126,7 +110,7 @@ class ChromaCollector(Collecter):
|
||||
|
||||
# If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
|
||||
if non_existing_texts:
|
||||
non_existing_embeddings = self.embedder.embed(non_existing_texts).tolist()
|
||||
non_existing_embeddings = embedder(non_existing_texts)
|
||||
for text, embedding in zip(non_existing_texts, non_existing_embeddings):
|
||||
self.embeddings_cache[text] = embedding
|
||||
|
||||
@ -145,7 +129,6 @@ class ChromaCollector(Collecter):
|
||||
self.id_to_info.update(new_info)
|
||||
self.ids.extend(new_ids)
|
||||
|
||||
|
||||
def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]):
|
||||
existing_texts, non_existing_texts = [], []
|
||||
existing_embeddings = []
|
||||
@ -169,7 +152,6 @@ class ChromaCollector(Collecter):
|
||||
return (existing_texts, existing_embeddings, existing_ids, existing_metas), \
|
||||
(non_existing_texts, non_existing_ids, non_existing_metas)
|
||||
|
||||
|
||||
def _get_new_ids(self, num_new_ids: int):
|
||||
if self.ids:
|
||||
max_existing_id = max(int(id_) for id_ in self.ids)
|
||||
@ -178,7 +160,6 @@ class ChromaCollector(Collecter):
|
||||
|
||||
return [str(i + max_existing_id + 1) for i in range(num_new_ids)]
|
||||
|
||||
|
||||
def _find_min_max_start_index(self):
|
||||
max_index, min_index = 0, float('inf')
|
||||
for _, val in self.id_to_info.items():
|
||||
@ -188,13 +169,14 @@ class ChromaCollector(Collecter):
|
||||
min_index = val['start_index']
|
||||
return min_index, max_index
|
||||
|
||||
|
||||
# NB: Does not make sense to weigh excerpts from different documents.
|
||||
# But let's say that's the user's problem. Perfect world scenario:
|
||||
# Apply time weighing to different documents. For each document, then, add
|
||||
# separate time weighing.
|
||||
|
||||
def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float):
|
||||
sigmoid = lambda x: 1 / (1 + np.exp(-x))
|
||||
def sigmoid(x):
|
||||
return 1 / (1 + np.exp(-x))
|
||||
|
||||
weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))
|
||||
|
||||
@ -210,7 +192,6 @@ class ChromaCollector(Collecter):
|
||||
index = info.start_index
|
||||
info.distance *= weights[index]
|
||||
|
||||
|
||||
def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float):
|
||||
# Ensure there are infos to filter
|
||||
if not infos:
|
||||
@ -231,7 +212,6 @@ class ChromaCollector(Collecter):
|
||||
|
||||
return filtered_infos
|
||||
|
||||
|
||||
def _merge_infos(self, infos: list[Info]):
|
||||
merged_infos = []
|
||||
current_info = infos[0]
|
||||
@ -247,8 +227,8 @@ class ChromaCollector(Collecter):
|
||||
merged_infos.append(current_info)
|
||||
return merged_infos
|
||||
|
||||
|
||||
# Main function for retrieving chunks by distance. It performs merging, time weighing, and mean filtering.
|
||||
|
||||
def _get_documents_ids_distances(self, search_strings: list[str], n_results: int):
|
||||
n_results = min(len(self.ids), n_results)
|
||||
if n_results == 0:
|
||||
@ -280,22 +260,22 @@ class ChromaCollector(Collecter):
|
||||
|
||||
return texts_with_context, ids, distances
|
||||
|
||||
|
||||
# Get chunks by similarity
|
||||
|
||||
def get(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
with self.lock:
|
||||
documents, _, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||
return documents
|
||||
|
||||
|
||||
# Get ids by similarity
|
||||
|
||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
with self.lock:
|
||||
_, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||
return ids
|
||||
|
||||
|
||||
# Cutoff token count
|
||||
|
||||
def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int):
|
||||
# TODO: Move to caller; We add delimiters there which might go over the limit.
|
||||
current_token_count = 0
|
||||
@ -318,8 +298,8 @@ class ChromaCollector(Collecter):
|
||||
|
||||
return return_documents
|
||||
|
||||
|
||||
# Get chunks by similarity and then sort by ids
|
||||
|
||||
def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
||||
with self.lock:
|
||||
documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
|
||||
@ -327,8 +307,8 @@ class ChromaCollector(Collecter):
|
||||
|
||||
return self._get_documents_up_to_token_count(sorted_docs, max_token_count)
|
||||
|
||||
|
||||
# Get chunks by similarity and then sort by distance (lowest distance is last).
|
||||
|
||||
def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
|
||||
with self.lock:
|
||||
documents, _, distances = self._get_documents_ids_distances(search_strings, n_results)
|
||||
@ -340,7 +320,6 @@ class ChromaCollector(Collecter):
|
||||
|
||||
return return_documents
|
||||
|
||||
|
||||
def delete(self, ids_to_delete: list[str], where: dict):
|
||||
with self.lock:
|
||||
ids_to_delete = self.collection.get(ids=ids_to_delete, where=where)['ids']
|
||||
@ -354,23 +333,16 @@ class ChromaCollector(Collecter):
|
||||
|
||||
logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.')
|
||||
|
||||
|
||||
def clear(self):
|
||||
with self.lock:
|
||||
self.chroma_client.reset()
|
||||
self.collection = self.chroma_client.create_collection("context", embedding_function=self.embedder.embed)
|
||||
|
||||
self.ids = []
|
||||
self.id_to_info = {}
|
||||
self.chroma_client.delete_collection(name=self.name)
|
||||
self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
|
||||
|
||||
logger.info('Successfully cleared all records and reset chromaDB.')
|
||||
|
||||
|
||||
class SentenceTransformerEmbedder(Embedder):
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Creating Sentence Embedder...')
|
||||
self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
|
||||
self.embed = self.model.encode
|
||||
|
||||
|
||||
def make_collector():
|
||||
return ChromaCollector(SentenceTransformerEmbedder())
|
||||
return ChromaCollector()
|
||||
|
@ -11,13 +11,12 @@ This module contains utils for preprocessing the text before converting it to em
|
||||
* removing specific parts of speech (adverbs and interjections)
|
||||
- TextSummarizer extracts the most important sentences from a long string using text-ranking.
|
||||
"""
|
||||
import pytextrank
|
||||
import string
|
||||
import spacy
|
||||
import math
|
||||
import nltk
|
||||
import re
|
||||
import string
|
||||
|
||||
import nltk
|
||||
import spacy
|
||||
from nltk.corpus import stopwords
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
from num2words import num2words
|
||||
@ -32,11 +31,9 @@ class TextPreprocessorBuilder:
|
||||
_lemmatizer_cache = {}
|
||||
_pos_remove_cache = {}
|
||||
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
|
||||
def to_lower(self):
|
||||
# Match both words and non-word characters
|
||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||
@ -49,7 +46,6 @@ class TextPreprocessorBuilder:
|
||||
self.text = "".join(tokens)
|
||||
return self
|
||||
|
||||
|
||||
def num_to_word(self, min_len: int = 1):
|
||||
# Match both words and non-word characters
|
||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||
@ -62,7 +58,6 @@ class TextPreprocessorBuilder:
|
||||
self.text = "".join(tokens)
|
||||
return self
|
||||
|
||||
|
||||
def num_to_char_long(self, min_len: int = 1):
|
||||
# Match both words and non-word characters
|
||||
tokens = re.findall(r'\b\w+\b|\W+', self.text)
|
||||
@ -71,7 +66,9 @@ class TextPreprocessorBuilder:
|
||||
if token.isdigit() and len(token) >= min_len:
|
||||
# This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
|
||||
# 740700 will become HHHHHHEEEEEAAAAHHHAAA
|
||||
convert_token = lambda token: ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]
|
||||
def convert_token(token):
|
||||
return ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]
|
||||
|
||||
tokens[i] = convert_token(tokens[i])
|
||||
self.text = "".join(tokens)
|
||||
return self
|
||||
@ -150,6 +147,7 @@ class TextPreprocessorBuilder:
|
||||
def build(self):
|
||||
return self.text
|
||||
|
||||
|
||||
class TextSummarizer:
|
||||
_nlp_pipeline = None
|
||||
_cache = {}
|
||||
|
@ -4,13 +4,14 @@ It will then split it into chunks of specified length. For each of those chunks,
|
||||
It will only include full words.
|
||||
"""
|
||||
|
||||
import re
|
||||
import bisect
|
||||
import re
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer
|
||||
from .chromadb import ChromaCollector
|
||||
from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer
|
||||
|
||||
|
||||
def preprocess_text_no_summary(text) -> str:
|
||||
builder = TextPreprocessorBuilder(text)
|
||||
|
@ -1,7 +1,7 @@
|
||||
import concurrent.futures
|
||||
import requests
|
||||
import re
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
@ -9,6 +9,7 @@ import extensions.superboogav2.parameters as parameters
|
||||
from .data_processor import process_and_add_to_collector
|
||||
from .utils import create_metadata_source
|
||||
|
||||
|
||||
def _download_single(url):
|
||||
response = requests.get(url, timeout=5)
|
||||
if response.status_code == 200:
|
||||
|
@ -4,13 +4,12 @@ This module is responsible for handling and modifying the notebook text.
|
||||
import re
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from extensions.superboogav2.utils import create_context_text
|
||||
from modules.logging_colors import logger
|
||||
|
||||
from .data_processor import preprocess_text
|
||||
|
||||
|
||||
def _remove_special_tokens(string):
|
||||
pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
|
||||
return re.sub(pattern, '', string)
|
||||
|
@ -3,22 +3,24 @@ This module implements a hyperparameter optimization routine for the embedding a
|
||||
|
||||
Each run, the optimizer will set the default values inside the hyperparameters. At the end, it will output the best ones it has found.
|
||||
"""
|
||||
import re
|
||||
import hashlib
|
||||
import json
|
||||
import optuna
|
||||
import logging
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import logging
|
||||
import hashlib
|
||||
logging.getLogger('optuna').setLevel(logging.WARNING)
|
||||
import optuna
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
logging.getLogger('optuna').setLevel(logging.WARNING)
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
from modules.logging_colors import logger
|
||||
|
||||
from .benchmark import benchmark
|
||||
from .parameters import Parameters
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
# Format the parameters into markdown format.
|
||||
@ -67,7 +69,7 @@ def _get_params_hash(params):
|
||||
|
||||
def optimize(collector, progress=gr.Progress()):
|
||||
# Inform the user that something is happening.
|
||||
progress(0, desc=f'Setting Up...')
|
||||
progress(0, desc='Setting Up...')
|
||||
|
||||
# Track the current step
|
||||
current_step = 0
|
||||
|
@ -6,13 +6,11 @@ Each element in the JSON must have a `default` value which will be used for the
|
||||
These categories define the range in which the optimizer will search. If the element is tagged with `"should_optimize": false`,
|
||||
then the optimizer will only ever use the default value.
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import json
|
||||
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
NUM_TO_WORD_METHOD = 'Number to Word'
|
||||
NUM_TO_CHAR_METHOD = 'Number to Char'
|
||||
NUM_TO_CHAR_LONG_METHOD = 'Number to Multi-Char'
|
||||
|
@ -1,5 +1,5 @@
|
||||
beautifulsoup4==4.12.2
|
||||
chromadb==0.3.18
|
||||
chromadb==0.4.24
|
||||
lxml
|
||||
optuna
|
||||
pandas==2.0.3
|
||||
|
@ -7,28 +7,29 @@ from pathlib import Path
|
||||
# Point to where nltk will find the required data.
|
||||
os.environ['NLTK_DATA'] = str(Path("extensions/superboogav2/nltk_data").resolve())
|
||||
|
||||
import textwrap
|
||||
import codecs
|
||||
import textwrap
|
||||
|
||||
import gradio as gr
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
from modules.logging_colors import logger
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
from .utils import create_metadata_source
|
||||
from .chromadb import make_collector
|
||||
from .download_urls import feed_url_into_collector
|
||||
from .data_processor import process_and_add_to_collector
|
||||
from .benchmark import benchmark
|
||||
from .optimize import optimize
|
||||
from .notebook_handler import input_modifier_internal
|
||||
from .chat_handler import custom_generate_chat_prompt_internal
|
||||
from .api import APIManager
|
||||
from .benchmark import benchmark
|
||||
from .chat_handler import custom_generate_chat_prompt_internal
|
||||
from .chromadb import make_collector
|
||||
from .data_processor import process_and_add_to_collector
|
||||
from .download_urls import feed_url_into_collector
|
||||
from .notebook_handler import input_modifier_internal
|
||||
from .optimize import optimize
|
||||
from .utils import create_metadata_source
|
||||
|
||||
collector = None
|
||||
api_manager = None
|
||||
|
||||
|
||||
def setup():
|
||||
global collector
|
||||
global api_manager
|
||||
@ -38,6 +39,7 @@ def setup():
|
||||
if parameters.get_api_on():
|
||||
api_manager.start_server(parameters.get_api_port())
|
||||
|
||||
|
||||
def _feed_data_into_collector(corpus):
|
||||
yield '### Processing data...'
|
||||
process_and_add_to_collector(corpus, collector, False, create_metadata_source('direct-text'))
|
||||
@ -305,10 +307,8 @@ def ui():
|
||||
optimize_button = gr.Button('Optimize')
|
||||
optimization_steps = gr.Number(value=parameters.get_optimization_steps(), label='Optimization Steps', info='For how many steps to optimize.', interactive=True)
|
||||
|
||||
|
||||
clear_button = gr.Button('❌ Clear Data')
|
||||
|
||||
|
||||
with gr.Column():
|
||||
last_updated = gr.Markdown()
|
||||
|
||||
@ -318,7 +318,6 @@ def ui():
|
||||
optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
|
||||
preprocess_pipeline, chunk_count, context_len, chunk_len]
|
||||
|
||||
|
||||
update_data.click(_feed_data_into_collector, [data_input], last_updated, show_progress=False)
|
||||
update_url.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)
|
||||
update_file.click(_feed_file_into_collector, [file_input], last_updated, show_progress=False)
|
||||
@ -326,7 +325,6 @@ def ui():
|
||||
optimize_button.click(_begin_optimization, [], [last_updated] + optimizable_params, show_progress=True)
|
||||
clear_button.click(_clear_data, [], last_updated, show_progress=False)
|
||||
|
||||
|
||||
optimization_steps.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
time_power.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
time_steepness.input(fn=_apply_settings, inputs=all_params, show_progress=False)
|
||||
|
@ -4,6 +4,7 @@ This module contains common functions across multiple other modules.
|
||||
|
||||
import extensions.superboogav2.parameters as parameters
|
||||
|
||||
|
||||
# Create the context using the prefix + data_separator + postfix from parameters.
|
||||
def create_context_text(results):
|
||||
context = parameters.get_prefix() + parameters.get_data_separator().join(results) + parameters.get_postfix()
|
||||
|
Loading…
Reference in New Issue
Block a user