Repository: https://github.com/oobabooga/text-generation-webui.git
Refactor superbooga

parent 826c74c201
commit 7cc17e3f1f

extensions/superbooga/chromadb.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import logging

import posthog
import torch
from sentence_transformers import SentenceTransformer

import chromadb
from chromadb.config import Settings

logging.info('Intercepting all calls to posthog :)')
posthog.capture = lambda *args, **kwargs: None


class Collecter():
    def __init__(self):
        pass

    def add(self, texts: list[str]):
        pass

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        pass

    def clear(self):
        pass


class Embedder():
    def __init__(self):
        pass

    def embed(self, text: str) -> list[torch.Tensor]:
        pass


class ChromaCollector(Collecter):
    def __init__(self, embedder: Embedder):
        super().__init__()
        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.embedder = embedder
        self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
        self.ids = []

    def add(self, texts: list[str]):
        self.ids = [f"id{i}" for i in range(len(texts))]
        self.collection.add(documents=texts, ids=self.ids)

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        n_results = min(len(self.ids), n_results)
        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
        return result

    def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
        n_results = min(len(self.ids), n_results)
        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
        return list(map(lambda x: int(x[2:]), result))

    def clear(self):
        self.collection.delete(ids=self.ids)


class SentenceTransformerEmbedder(Embedder):
    def __init__(self) -> None:
        self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
        self.embed = self.model.encode


def make_collector():
    global embedder
    return ChromaCollector(embedder)


def add_chunks_to_collector(chunks, collector):
    collector.clear()
    collector.add(chunks)


embedder = SentenceTransformerEmbedder()
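For orientation, here is a minimal usage sketch of the new module. It is not part of the commit; it assumes the chromadb and sentence-transformers packages are installed and that it is run from the web UI's root directory so the package import resolves. Note that importing the module instantiates `embedder`, which loads the all-mpnet-base-v2 model.

```python
# Minimal sketch (not from the commit): exercising the helpers defined above.
from extensions.superbooga.chromadb import add_chunks_to_collector, make_collector

chunks = ["The sky is blue.", "Chroma stores embeddings.", "Llamas live in the Andes."]

collector = make_collector()                # ChromaCollector backed by the module-level embedder
add_chunks_to_collector(chunks, collector)  # clear() then add(): each call replaces the stored chunks

# Retrieve the chunks most similar to a query, as texts and as integer indices into `chunks`
print(collector.get(["Where do llamas live?"], n_results=2))
print(collector.get_ids(["Where do llamas live?"], n_results=2))
```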
extensions/superbooga/script.py

@@ -2,22 +2,14 @@ import logging
 import re
 import textwrap
 
-import chromadb
 import gradio as gr
-import posthog
-import torch
-from bs4 import BeautifulSoup
-from chromadb.config import Settings
-from sentence_transformers import SentenceTransformer
 
 from modules import chat, shared
 
+from .chromadb import add_chunks_to_collector, make_collector
 from .download_urls import download_urls
 
-logging.info('Intercepting all calls to posthog :)')
-posthog.capture = lambda *args, **kwargs: None
-
 # These parameters are customizable through settings.json
 params = {
     'chunk_count': 5,
     'chunk_length': 700,
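The comment above says these params can be overridden through settings.json. As a rough, hypothetical sketch of that merge (the "<extension>-<param>" key convention and the helper name are assumptions, not shown in this diff), the idea is along these lines:

```python
# Hypothetical sketch: merging settings.json overrides into an extension's params dict.
# Assumes keys of the form "<extension>-<param>", e.g. "superbooga-chunk_count"; illustrative only.
import json

def apply_user_settings(params, extension='superbooga', settings_path='settings.json'):
    with open(settings_path) as f:
        user_settings = json.load(f)

    for key in params:
        qualified = f'{extension}-{key}'
        if qualified in user_settings:
            params[key] = user_settings[qualified]

    return params
```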
@@ -25,72 +17,11 @@ params = {
     'threads': 4,
 }
 
 
-class Collecter():
-    def __init__(self):
-        pass
-
-    def add(self, texts: list[str]):
-        pass
-
-    def get(self, search_strings: list[str], n_results: int) -> list[str]:
-        pass
-
-    def clear(self):
-        pass
-
-
-class Embedder():
-    def __init__(self):
-        pass
-
-    def embed(self, text: str) -> list[torch.Tensor]:
-        pass
-
-
-class ChromaCollector(Collecter):
-    def __init__(self, embedder: Embedder):
-        super().__init__()
-        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
-        self.embedder = embedder
-        self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
-        self.ids = []
-
-    def add(self, texts: list[str]):
-        self.ids = [f"id{i}" for i in range(len(texts))]
-        self.collection.add(documents=texts, ids=self.ids)
-
-    def get(self, search_strings: list[str], n_results: int) -> list[str]:
-        n_results = min(len(self.ids), n_results)
-        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
-        return result
-
-    def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
-        n_results = min(len(self.ids), n_results)
-        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
-        return list(map(lambda x: int(x[2:]), result))
-
-    def clear(self):
-        self.collection.delete(ids=self.ids)
-
-
-class SentenceTransformerEmbedder(Embedder):
-    def __init__(self) -> None:
-        self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
-        self.embed = self.model.encode
-
-
-embedder = SentenceTransformerEmbedder()
-collector = ChromaCollector(embedder)
-chat_collector = ChromaCollector(embedder)
+collector = make_collector()
+chat_collector = make_collector()
 chunk_count = 5
-
-
-def add_chunks_to_collector(chunks, collector):
-    collector.clear()
-    collector.add(chunks)
 
 
 def feed_data_into_collector(corpus, chunk_len):
     global collector
@@ -150,6 +81,7 @@ def apply_settings(_chunk_count):
     settings_to_display = {
         'chunk_count': chunk_count,
     }
 
     yield f"The following settings are now active: {str(settings_to_display)}"
 
@@ -193,10 +125,8 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
 
 
 def remove_special_tokens(string):
-    for k in ['<|begin-user-input|>', '<|end-user-input|>', '<|injection-point|>']:
-        string = string.replace(k, '')
-
-    return string.strip()
+    pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
+    return re.sub(pattern, '', string)
 
 
 def input_modifier(string):
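As a quick illustration with a hypothetical input string (not from the commit), the regex-based version strips the three markers wherever they occur and leaves the surrounding text untouched:

```python
import re

def remove_special_tokens(string):
    pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
    return re.sub(pattern, '', string)

s = "Question: <|begin-user-input|>What is Chroma?<|end-user-input|>\n<|injection-point|>"
print(remove_special_tokens(s))
# Returned string is "Question: What is Chroma?\n": markers removed, other text and whitespace kept.
```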
@@ -208,17 +138,14 @@ def input_modifier(string):
     match = re.search(pattern, string)
     if match:
         user_input = match.group(1).strip()
-    else:
-        return remove_special_tokens(string)
 
-    # Get the most similar chunks
-    results = collector.get(user_input, n_results=chunk_count)
+        # Get the most similar chunks
+        results = collector.get(user_input, n_results=chunk_count)
 
-    # Make the replacements
-    string = string.replace('<|begin-user-input|>', '').replace('<|end-user-input|>', '')
-    string = string.replace('<|injection-point|>', '\n'.join(results))
+        # Make the injection
+        string = string.replace('<|injection-point|>', '\n'.join(results))
 
-    return string
+    return remove_special_tokens(string)
 
 
 def ui():
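To make the flow above concrete, here is a hypothetical notebook prompt (text and retrieved chunk invented for illustration) and the rough shape of what input_modifier returns: the user input between the markers is used as the query, the retrieved chunks are joined with newlines at the injection point, and the special tokens are then stripped.

```python
# Hypothetical prompt passed through input_modifier (illustrative, not from the commit).
prompt = (
    "Consider the excerpts below.\n"
    "<|injection-point|>\n"
    "<|begin-user-input|>Where do llamas live?<|end-user-input|>"
)

# With a matching chunk in the collector, the modified prompt is roughly:
#   "Consider the excerpts below.\n"
#   "Llamas live in the Andes.\n"      <- chunk(s) returned by collector.get(), newline-joined
#   "Where do llamas live?"            <- markers removed by remove_special_tokens()
```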
@@ -250,7 +177,7 @@ def ui():
     ...
     ```
 
-    The injection doesn't make it into the chat history. It is only used in the current generation.
+    The injection doesn't make it into the chat history. It is only used in the current generation.
 
     #### Regular chat
 