Refactor superbooga

oobabooga 2023-05-13 14:14:59 -03:00
parent 826c74c201
commit 7cc17e3f1f
2 changed files with 90 additions and 85 deletions

View File: extensions/superbooga/chromadb.py (new file)

@@ -0,0 +1,78 @@
import logging

import posthog
import torch
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings

logging.info('Intercepting all calls to posthog :)')
posthog.capture = lambda *args, **kwargs: None


class Collecter():
    def __init__(self):
        pass

    def add(self, texts: list[str]):
        pass

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        pass

    def clear(self):
        pass


class Embedder():
    def __init__(self):
        pass

    def embed(self, text: str) -> list[torch.Tensor]:
        pass


class ChromaCollector(Collecter):
    def __init__(self, embedder: Embedder):
        super().__init__()
        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.embedder = embedder
        self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
        self.ids = []

    def add(self, texts: list[str]):
        self.ids = [f"id{i}" for i in range(len(texts))]
        self.collection.add(documents=texts, ids=self.ids)

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        n_results = min(len(self.ids), n_results)
        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
        return result

    def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
        n_results = min(len(self.ids), n_results)
        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
        return list(map(lambda x: int(x[2:]), result))

    def clear(self):
        self.collection.delete(ids=self.ids)


class SentenceTransformerEmbedder(Embedder):
    def __init__(self) -> None:
        self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
        self.embed = self.model.encode


def make_collector():
    global embedder
    return ChromaCollector(embedder)


def add_chunks_to_collector(chunks, collector):
    collector.clear()
    collector.add(chunks)


embedder = SentenceTransformerEmbedder()
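For orientation, here is a minimal usage sketch of the module extracted above (not part of the commit): `make_collector()` wraps the module-level `SentenceTransformerEmbedder` in a fresh `ChromaCollector`, and `add_chunks_to_collector()` clears and repopulates an existing one. The import path and the sample chunks are assumptions for illustration only.

```python
# Illustrative sketch, assuming it is run from the repository root so that the
# extension package is importable; the sample text is made up.
from extensions.superbooga.chromadb import add_chunks_to_collector, make_collector

collector = make_collector()  # ChromaCollector built around the shared embedder
collector.add([
    "The mitochondria is the powerhouse of the cell.",
    "Paris is the capital of France.",
    "Python uses indentation to delimit blocks.",
])

# Retrieve the chunk most similar to a query (embedding similarity, so the
# exact result is not guaranteed, but the Paris chunk is expected here).
print(collector.get(["capital city of France"], n_results=1))

# For subsequent loads, add_chunks_to_collector() clears the old ids first:
# add_chunks_to_collector(new_chunks, collector)
```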

View File: extensions/superbooga/script.py

@@ -2,22 +2,14 @@ import logging
import re
import textwrap
-import chromadb
import gradio as gr
-import posthog
-import torch
from bs4 import BeautifulSoup
-from chromadb.config import Settings
-from sentence_transformers import SentenceTransformer
from modules import chat, shared
+from .chromadb import add_chunks_to_collector, make_collector
from .download_urls import download_urls
-logging.info('Intercepting all calls to posthog :)')
-posthog.capture = lambda *args, **kwargs: None
-# These parameters are customizable through settings.json
params = {
    'chunk_count': 5,
    'chunk_length': 700,
@@ -25,72 +17,11 @@ params = {
    'threads': 4,
}
-class Collecter():
-    def __init__(self):
-        pass
-    def add(self, texts: list[str]):
-        pass
-    def get(self, search_strings: list[str], n_results: int) -> list[str]:
-        pass
-    def clear(self):
-        pass
-class Embedder():
-    def __init__(self):
-        pass
-    def embed(self, text: str) -> list[torch.Tensor]:
-        pass
-class ChromaCollector(Collecter):
-    def __init__(self, embedder: Embedder):
-        super().__init__()
-        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
-        self.embedder = embedder
-        self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
-        self.ids = []
-    def add(self, texts: list[str]):
-        self.ids = [f"id{i}" for i in range(len(texts))]
-        self.collection.add(documents=texts, ids=self.ids)
-    def get(self, search_strings: list[str], n_results: int) -> list[str]:
-        n_results = min(len(self.ids), n_results)
-        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
-        return result
-    def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
-        n_results = min(len(self.ids), n_results)
-        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
-        return list(map(lambda x: int(x[2:]), result))
-    def clear(self):
-        self.collection.delete(ids=self.ids)
-class SentenceTransformerEmbedder(Embedder):
-    def __init__(self) -> None:
-        self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
-        self.embed = self.model.encode
-embedder = SentenceTransformerEmbedder()
-collector = ChromaCollector(embedder)
-chat_collector = ChromaCollector(embedder)
+collector = make_collector()
+chat_collector = make_collector()
chunk_count = 5
-def add_chunks_to_collector(chunks, collector):
-    collector.clear()
-    collector.add(chunks)
def feed_data_into_collector(corpus, chunk_len):
    global collector
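In the hunk above, the inline class definitions give way to two module-level collectors built through the new factory: `collector` receives chunks from files, URLs, or pasted text, and `chat_collector` receives chat history. The chunking that `feed_data_into_collector` performs before calling `add_chunks_to_collector` is not shown in this diff; the sketch below assumes simple fixed-length character windows.

```python
# Assumption: feed_data_into_collector slices the corpus into fixed-length
# character windows before handing them to add_chunks_to_collector.
def chunk_corpus(corpus: str, chunk_len: int) -> list[str]:
    chunk_len = int(chunk_len)
    return [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]

# Hypothetical usage inside the extension:
# chunks = chunk_corpus(corpus, params['chunk_length'])
# add_chunks_to_collector(chunks, collector)
```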
@@ -150,6 +81,7 @@ def apply_settings(_chunk_count):
    settings_to_display = {
        'chunk_count': chunk_count,
    }
    yield f"The following settings are now active: {str(settings_to_display)}"
@@ -193,10 +125,8 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
def remove_special_tokens(string):
-    for k in ['<|begin-user-input|>', '<|end-user-input|>', '<|injection-point|>']:
-        string = string.replace(k, '')
-    return string.strip()
+    pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
+    return re.sub(pattern, '', string)
def input_modifier(string):
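The rewritten `remove_special_tokens` replaces the per-token loop with a single `re.sub` over an alternation pattern. A quick self-contained check (the sample string is invented); note that the new version no longer calls `.strip()`, so surrounding whitespace is preserved.

```python
import re

def remove_special_tokens(string):
    # Same alternation pattern as the new version above.
    pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
    return re.sub(pattern, '', string)

s = "Context: <|injection-point|>\n<|begin-user-input|>What is RAG?<|end-user-input|>"
print(remove_special_tokens(s))
# Prints "Context: " and "What is RAG?" on two lines; the control tokens are gone.
```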
@@ -208,17 +138,14 @@ def input_modifier(string):
    match = re.search(pattern, string)
    if match:
        user_input = match.group(1).strip()
+    else:
+        return remove_special_tokens(string)
    # Get the most similar chunks
    results = collector.get(user_input, n_results=chunk_count)
-    # Make the replacements
-    string = string.replace('<|begin-user-input|>', '').replace('<|end-user-input|>', '')
-    string = string.replace('<|injection-point|>', '\n'.join(results))
-    return string
+    # Make the injection
+    string = string.replace('<|injection-point|>', '\n'.join(results))
+    return remove_special_tokens(string)
def ui():
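Putting the `input_modifier` changes together: when the prompt contains a marked user-input block, the retrieved chunks are spliced in at `<|injection-point|>`; when it does not, the function now returns early instead of running a pointless query. The sketch below mirrors that flow under stated assumptions; the search pattern and the `collector` argument are stand-ins, since the full function body is not shown in this diff.

```python
import re

TOKEN_PATTERN = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'

def remove_special_tokens(string):
    return re.sub(TOKEN_PATTERN, '', string)

def input_modifier_sketch(string, collector, chunk_count=5):
    # Sketch of the refactored flow, not the commit's exact code.
    match = re.search(r'<\|begin-user-input\|>(.*?)<\|end-user-input\|>', string, re.DOTALL)
    if match is None:
        # No marked user input: just strip any stray control tokens and bail out.
        return remove_special_tokens(string)
    user_input = match.group(1).strip()

    # Fetch the most similar chunks and splice them in at the injection point.
    results = collector.get(user_input, n_results=chunk_count)
    string = string.replace('<|injection-point|>', '\n'.join(results))
    return remove_special_tokens(string)
```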
@@ -250,7 +177,7 @@ def ui():
    ...
    ```
The injection doesn't make it into the chat history. It is only used in the current generation.
#### Regular chat