import logging
import re
import textwrap

from urllib.request import urlopen

import chromadb
import gradio as gr
import posthog
import torch
from bs4 import BeautifulSoup
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

from modules import chat, shared
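
# chromadb imports posthog for telemetry; replacing posthog.capture with a
# no-op keeps this extension fully offline.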
logging.info('Intercepting all calls to posthog :)')
posthog.capture = lambda *args, **kwargs: None


class Collecter():
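    """Abstract interface for a store of text chunks."""
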
    def __init__(self):
        pass

    def add(self, texts: list[str]):
        pass

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        pass

    def clear(self):
        pass


class Embedder():
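    """Abstract interface for turning text into embedding vectors."""
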
    def __init__(self):
        pass

    def embed(self, text: str) -> list[torch.Tensor]:
        pass


class ChromaCollector(Collecter):
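    """Collecter backed by an in-memory Chroma collection.

    Chunks are stored under sequential string ids ("id0", "id1", ...),
    which get_ids() maps back to integer positions.
    """
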
    def __init__(self, embedder: Embedder):
        super().__init__()
        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.embedder = embedder
        self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
        self.ids = []

    def add(self, texts: list[str]):
        self.ids = [f"id{i}" for i in range(len(texts))]
        self.collection.add(documents=texts, ids=self.ids)

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
        return result

    def get_ids(self, search_strings: list[str], n_results: int) -> list[int]:
        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
        return list(map(lambda x: int(x[2:]), result))

    def clear(self):
        # Note: self.ids only tracks the most recent add() batch, which is
        # what gets deleted here.
        self.collection.delete(ids=self.ids)


class SentenceTransformerEmbedder(Embedder):
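    """Embedder that delegates to the all-mpnet-base-v2 SentenceTransformer model."""
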
    def __init__(self) -> None:
        self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
        self.embed = self.model.encode
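

# A minimal sketch of how the pieces above fit together (illustrative only,
# not executed at import time):
#
#     emb = SentenceTransformerEmbedder()
#     col = ChromaCollector(emb)
#     col.add(["first chunk", "second chunk"])
#     col.get(["a query"], n_results=1)  # -> list with the closest chunk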

# Module-level singletons shared by all of the handlers below
embedder = SentenceTransformerEmbedder()
collector = ChromaCollector(embedder)

# Number of closest-matching chunks to inject into the prompt
chunk_count = 5


def add_chunks_to_collector(chunks):
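    """Replace the current contents of the database with the given chunks."""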
    global collector
    collector.clear()
    collector.add(chunks)


def feed_data_into_collector(corpus, chunk_len):
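    """Split the corpus into chunks of chunk_len characters and load them into
    the database, yielding progress messages for the UI along the way."""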
    # Defining variables
    chunk_len = int(chunk_len)
    cumulative = ''

    # Breaking the data into chunks and adding those to the db
    cumulative += "Breaking the input dataset...\n\n"
    yield cumulative
    data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
    cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
    yield cumulative
    add_chunks_to_collector(data_chunks)
    cumulative += "Done."
    yield cumulative


def feed_file_into_collector(file, chunk_len):
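    """Decode an uploaded file as UTF-8 and feed its text into the database."""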
    yield 'Reading the input dataset...\n\n'
    text = file.decode('utf-8')
    for i in feed_data_into_collector(text, chunk_len):
        yield i


def feed_url_into_collector(urls, chunk_len):
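    """Download one or more newline-separated URLs, reduce the HTML to its
    visible text, and feed the result into the database."""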
    urls = urls.strip().split('\n')
    all_text = ''
    cumulative = ''
    for url in urls:
        cumulative += f'Loading {url}...\n\n'
        yield cumulative
        html = urlopen(url).read()
        soup = BeautifulSoup(html, features="html.parser")

        # Remove non-visible elements before extracting the text
        for script in soup(["script", "style"]):
            script.extract()

        text = soup.get_text()

        # Collapse whitespace: strip each line, break multi-headlines apart on
        # double spaces, and drop the empty pieces
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
        text = '\n\n'.join(chunk for chunk in chunks if chunk)
        all_text += text

    for i in feed_data_into_collector(all_text, chunk_len):
        yield i


def apply_settings(_chunk_count):
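    """Update the global chunk_count from the UI and report the active settings."""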
    global chunk_count
    chunk_count = int(_chunk_count)
    settings_to_display = {
        'chunk_count': chunk_count,
    }
    yield f"The following settings are now active: {str(settings_to_display)}"


def input_modifier(string):
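    """Replace the special tokens in notebook/default prompts: the text between
    <|begin-user-input|> and <|end-user-input|> becomes the query, and
    <|injection-point|> is replaced with the closest-matching chunks."""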
    if shared.is_chat():
        return string

    # Find the user input
    pattern = re.compile(r"<\|begin-user-input\|>(.*?)<\|end-user-input\|>", re.DOTALL)
    match = re.search(pattern, string)
    if match:
        user_input = match.group(1).strip()
    else:
        user_input = ''

    # Get the most similar chunks
    results = collector.get([user_input], n_results=chunk_count)

    # Make the replacements
    string = string.replace('<|begin-user-input|>', '')
    string = string.replace('<|end-user-input|>', '')
    string = string.replace('<|injection-point|>', '\n'.join(results))

    return string


def custom_generate_chat_prompt(user_input, state, **kwargs):
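    """In chat mode, reorder the older exchanges by relevance to the current
    input (keeping the latest exchange in place) before building the prompt."""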
    if len(shared.history['internal']) > 2 and user_input != '':
        chunks = []
        for i in range(len(shared.history['internal']) - 1):
            chunks.append('\n'.join(shared.history['internal'][i]))

        add_chunks_to_collector(chunks)
        query = '\n'.join(shared.history['internal'][-1] + [user_input])
        try:
            best_ids = collector.get_ids([query], n_results=len(shared.history['internal']) - 1)

            # Sort the history by relevance instead of by chronological order,
            # except for the latest message
            state['history'] = [shared.history['internal'][id_] for id_ in best_ids[::-1]] + [shared.history['internal'][-1]]
        except RuntimeError:
            logging.error("Couldn't query the database, moving on...")

    return chat.generate_chat_prompt(user_input, state, **kwargs)


def ui():
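    """Create the Gradio UI for the extension."""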
    with gr.Accordion("Click for more information...", open=False):
        gr.Markdown(textwrap.dedent("""

        ## About

        This extension takes a dataset as input, breaks it into chunks, and adds the result to a local/offline Chroma database.

        The database is then queried during inference time to get the excerpts that are closest to your input. The idea is to create an arbitrarily large pseudocontext.

        It is a modified version of the superbig extension by kaiokendev: https://github.com/kaiokendev/superbig

        ## Notebook/default modes

        ### How to use it

        1) Paste your input text (of whatever length) into the text box below.
        2) Click on "Load data" to feed this text into the Chroma database.
        3) In your prompt, enter your question between `<|begin-user-input|>` and `<|end-user-input|>`, and specify the injection point with `<|injection-point|>`.

        By default, the 5 closest chunks will be injected. You can customize this value in the "Generation settings" tab.

        The special tokens mentioned above (`<|begin-user-input|>`, `<|end-user-input|>`, and `<|injection-point|>`) are removed when the injection happens.

        ### Example

        For your convenience, you can use the following prompt as a starting point (for Alpaca models):

        ```
        Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

        ### Instruction:
        You are ArxivGPT, trained on millions of Arxiv papers. You always answer the question, even if full context isn't provided to you. The following are snippets from an Arxiv paper. Use the snippets to answer the question. Think about it step by step

        <|injection-point|>

        ### Input:
        <|begin-user-input|>
        What datasets are mentioned in the paper above?
        <|end-user-input|>

        ### Response:
        ```

        ## Chat mode

        In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair. That is, the prompt will include (starting from the end):

        * Your input
        * The latest input/reply pair
        * The #1 most relevant input/reply pair prior to the latest
        * The #2 most relevant input/reply pair prior to the latest
        * Etc

        This way, the bot can have a long term history.

        *This extension is currently experimental and under development.*

        """))

    if not shared.is_chat():
        with gr.Row():
            with gr.Column():
                with gr.Tab("Text input"):
                    data_input = gr.Textbox(lines=20, label='Input data')
                    update_data = gr.Button('Load data')

                with gr.Tab("URL input"):
                    url_input = gr.Textbox(lines=10, label='Input URLs', info='Enter one or more URLs separated by newline characters.')
                    update_url = gr.Button('Load data')

                with gr.Tab("File input"):
                    file_input = gr.File(label='Input file', type='binary')
                    update_file = gr.Button('Load data')

                with gr.Tab("Generation settings"):
                    chunk_count = gr.Number(value=5, label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
                    update_settings = gr.Button('Apply changes')

                chunk_len = gr.Number(value=700, label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')

            with gr.Column():
                last_updated = gr.Markdown()

        update_data.click(feed_data_into_collector, [data_input, chunk_len], last_updated, show_progress=False)
        update_url.click(feed_url_into_collector, [url_input, chunk_len], last_updated, show_progress=False)
        update_file.click(feed_file_into_collector, [file_input, chunk_len], last_updated, show_progress=False)
        update_settings.click(apply_settings, [chunk_count], last_updated, show_progress=False)