2023-05-07 08:50:12 +02:00
import re
import textwrap
import gradio as gr
2023-05-07 09:49:02 +02:00
from bs4 import BeautifulSoup
2023-05-20 23:42:17 +02:00
2023-05-07 20:01:14 +02:00
from modules import chat , shared
2023-05-22 03:42:34 +02:00
from modules . logging_colors import logger
2023-05-07 20:01:14 +02:00
2023-05-13 19:14:59 +02:00
from . chromadb import add_chunks_to_collector , make_collector
2023-05-12 19:19:55 +02:00
from . download_urls import download_urls
2023-05-11 04:23:37 +02:00
# Default extension settings. The UI widgets use these as their initial values,
# and apply_settings() overwrites the retrieval-related keys at runtime.
params = dict(
    chunk_count=5,  # number of closest-matching chunks included in the prompt
    chunk_count_initial=10,  # candidates retrieved before time-weight reordering in chat mode (-1 = all)
    time_weight=0,  # strength of the time weighting; 0 = no time weighting
    chunk_length=700,  # chunk size in characters (not tokens), used when loading data
    chunk_separator='',  # optional manual separator; '' = split by length only
    strong_cleanup=False,  # URL input: keep only text that looks like long-form prose
    threads=4,  # URL input: number of download threads
)
2023-05-13 19:14:59 +02:00
# Two independent vector stores: `collector` receives chunks of user-loaded data
# (text, URLs, files); `chat_collector` receives past chat exchanges in chat mode.
collector, chat_collector = make_collector(), make_collector()
2023-05-07 08:50:12 +02:00
2023-05-15 02:44:52 +02:00
def _split_into_chunks(corpus, chunk_len, chunk_sep):
    """Split *corpus* into pieces of at most *chunk_len* characters.

    When *chunk_sep* is non-empty, the corpus is first split on it and every
    resulting piece longer than *chunk_len* is split again by length. Empty
    pieces (e.g. from consecutive separators) produce no chunks.
    """
    pieces = corpus.split(chunk_sep) if chunk_sep else [corpus]
    return [piece[i:i + chunk_len] for piece in pieces for i in range(0, len(piece), chunk_len)]


def feed_data_into_collector(corpus, chunk_len, chunk_sep):
    """Chunk *corpus* and store the chunks in the data collector.

    Used as a Gradio generator callback: every yielded value is the full,
    cumulative status message to display.

    Args:
        corpus: The text to index.
        chunk_len: Maximum chunk length in characters. It is converted with
            ``int`` because it comes from a numeric widget.
        chunk_sep: Optional manual separator. The two characters backslash
            and ``n`` are turned into a real newline. Empty means split by
            length only.

    Yields:
        Progress messages. If *chunk_len* is smaller than 1, a single error
        message is yielded and nothing is added to the database.
    """
    chunk_len = int(chunk_len)
    chunk_sep = chunk_sep.replace(r'\n', '\n')
    if chunk_len < 1:
        # range() with a zero step raises ValueError, and a negative step
        # silently produces no chunks at all.
        yield "Chunk length must be a positive integer."
        return

    cumulative = "Breaking the input dataset...\n\n"
    yield cumulative
    data_chunks = _split_into_chunks(corpus, chunk_len, chunk_sep)

    cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
    yield cumulative
    add_chunks_to_collector(data_chunks, collector)
    cumulative += "Done."
    yield cumulative
2023-05-15 02:44:52 +02:00
def feed_file_into_collector(file, chunk_len, chunk_sep):
    """Decode an uploaded file as UTF-8 and feed its text into the data collector.

    Args:
        file: The file contents as ``bytes`` (the UI uses
            ``gr.File(type='binary')``), or ``None`` when "Load data" is
            clicked without an upload.
        chunk_len: Forwarded to ``feed_data_into_collector``.
        chunk_sep: Forwarded to ``feed_data_into_collector``.

    Yields:
        Status messages for the UI.
    """
    if file is None:
        # Without this guard, file.decode() raises AttributeError.
        yield 'No file has been provided.'
        return

    yield 'Reading the input dataset...\n\n'
    text = file.decode('utf-8')
    yield from feed_data_into_collector(text, chunk_len, chunk_sep)
2023-05-15 02:44:52 +02:00
def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads):
    """Download web pages, extract their visible text, and feed it into the data collector.

    Args:
        urls: Newline-separated URLs. Surrounding whitespace and blank lines
            are ignored.
        chunk_len: Forwarded to ``feed_data_into_collector``.
        chunk_sep: Forwarded to ``feed_data_into_collector``.
        strong_cleanup: If true, keep only strings containing a letter
            followed by a space (a cheap heuristic for long-form text).
        threads: Number of download threads passed to ``download_urls``.

    Yields:
        Cumulative status messages for the UI.
    """
    # splitlines() also handles "\r\n"; blank lines would otherwise become bogus URLs.
    url_list = [url.strip() for url in urls.splitlines() if url.strip()]
    if not url_list:
        yield 'No URLs have been provided.'
        return

    cumulative = f'Loading {len(url_list)} URLs with {threads} threads...\n\n'
    yield cumulative
    # download_urls yields (update, contents) pairs; the contents from the last
    # iteration are processed below. Pre-set so an empty iteration cannot leave
    # `contents` unbound.
    contents = []
    for update, contents in download_urls(url_list, threads=threads):
        yield cumulative + update

    cumulative += 'Processing the HTML sources...'
    yield cumulative
    page_texts = []
    for content in contents:
        soup = BeautifulSoup(content, features="lxml")
        # Drop <script> and <style> elements before collecting text.
        for tag in soup(["script", "style"]):
            tag.extract()

        strings = soup.stripped_strings
        if strong_cleanup:
            strings = [s for s in strings if re.search("[A-Za-z] ", s)]

        page_texts.append('\n'.join([s.strip() for s in strings]))

    # Separate pages with a newline so the last line of one page is not glued
    # onto the first line of the next.
    all_text = '\n'.join(page_texts)
    yield from feed_data_into_collector(all_text, chunk_len, chunk_sep)
2023-05-07 08:50:12 +02:00
2023-05-25 15:22:45 +02:00
def apply_settings(chunk_count, chunk_count_initial, time_weight):
    """Store the retrieval settings from the UI in the module-level ``params`` dict.

    ``chunk_count`` and ``chunk_count_initial`` are converted with ``int``;
    ``time_weight`` is stored as given. Yields one confirmation message that
    lists the three updated settings.
    """
    # The dict is mutated in place, so no `global` declaration is required.
    params['chunk_count'] = int(chunk_count)
    params['chunk_count_initial'] = int(chunk_count_initial)
    params['time_weight'] = time_weight

    shown_keys = ('chunk_count', 'chunk_count_initial', 'time_weight')
    active = {key: params[key] for key in params if key in shown_keys}
    yield f"The following settings are now active: {active}"
2023-05-13 17:50:19 +02:00
def custom_generate_chat_prompt(user_input, state, **kwargs):
    """Build the chat prompt, augmenting it with retrieved context.

    In instruct mode, the loaded-data chunks closest to *user_input* are
    appended to the user input for this call.

    In the other modes the chat history is used as long-term memory. The
    condition is that the history has more than ``params['chunk_count']``
    exchanges and *user_input* is non-empty. When it holds:

    - every exchange except the latest is stored in ``chat_collector``;
    - the exchanges most relevant to the latest exchange plus *user_input*
      are moved from the history into ``state['context']``;
    - the shortened history is passed on.

    Args:
        user_input: The new user message.
        state: UI state dict. Reads ``mode``, ``name1``, ``name2`` and
            ``context``. NOTE: ``state['context']`` is modified in place.
        **kwargs: Forwarded to ``chat.generate_chat_prompt``. Must contain
            ``history`` whose ``internal`` entry is a list of exchanges
            indexed as ``[0]`` (user) and ``[1]`` (bot).

    Returns:
        The result of ``chat.generate_chat_prompt``.
    """
    # Read the history from kwargs because it is modified when using regenerate.
    history = kwargs['history']

    if state['mode'] == 'instruct':
        results = collector.get_sorted(user_input, n_results=params['chunk_count'])
        additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
        user_input += additional_context
    else:
        internal = history['internal']

        def make_single_exchange(id_):
            return (f"{state['name1']}: {internal[id_][0]}\n"
                    f"{state['name2']}: {internal[id_][1]}\n")

        hist_size = len(internal)
        if hist_size > params['chunk_count'] and user_input != '':
            # The latest exchange is not stored; it becomes part of the query instead.
            chunks = [make_single_exchange(i) for i in range(hist_size - 1)]
            add_chunks_to_collector(chunks, chat_collector)
            query = '\n'.join(internal[-1] + [user_input])
            try:
                best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
                additional_context = '\n'
                moved_ids = set()
                for id_ in best_ids:
                    # Exchanges whose user side is the '<|BEGIN-VISIBLE-CHAT|>' marker are
                    # not copied into the context, so they must stay in the history.
                    if internal[id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
                        additional_context += make_single_exchange(id_)
                        moved_ids.add(id_)

                logger.warning(f'Adding the following new context:\n{additional_context}')
                state['context'] = state['context'].strip() + '\n' + additional_context
                # Remove only the exchanges that were actually moved into the context;
                # previously skipped ones were dropped from the prompt entirely.
                kwargs['history'] = {
                    'internal': [exchange for i, exchange in enumerate(internal) if i not in moved_ids],
                    'visible': ''
                }
            except RuntimeError:
                logger.error("Couldn't query the database, moving on...")

    return chat.generate_chat_prompt(user_input, state, **kwargs)
def remove_special_tokens(string):
    """Return *string* with every special marker token removed.

    The markers are ``<|begin-user-input|>``, ``<|end-user-input|>`` and
    ``<|injection-point|>``. Removal is a single left-to-right pass, so text
    that only forms a marker after another marker is removed is kept.
    """
    markers = ('<|begin-user-input|>', '<|end-user-input|>', '<|injection-point|>')
    marker_regex = '|'.join(re.escape(marker) for marker in markers)
    return re.sub(marker_regex, '', string)
2023-05-13 17:50:19 +02:00
2023-05-07 08:50:12 +02:00
def input_modifier(string):
    """Inject retrieved chunks into a notebook/default-mode prompt.

    In chat mode the prompt is returned unchanged. Otherwise:

    - the text between ``<|begin-user-input|>`` and ``<|end-user-input|>``
      (stripped) is used as the query;
    - if it is found, the closest chunks replace every
      ``<|injection-point|>``;
    - finally all special tokens are removed.
    """
    if shared.is_chat():
        return string

    user_input_regex = re.compile(r"<\|begin-user-input\|>(.*?)<\|end-user-input\|>", re.DOTALL)
    found = user_input_regex.search(string)
    if found is not None:
        query = found.group(1).strip()
        closest_chunks = collector.get_sorted(query, n_results=params['chunk_count'])
        string = string.replace('<|injection-point|>', '\n'.join(closest_chunks))

    return remove_special_tokens(string)
2023-05-07 08:50:12 +02:00
def ui():
    """Build the extension's Gradio interface and wire the buttons to their callbacks.

    Each "Load data" button streams the messages of the matching
    ``feed_*_into_collector`` generator into the ``last_updated`` Markdown
    element. "Apply changes" does the same with ``apply_settings``.
    """
    with gr.Accordion("Click for more information...", open=False):
        gr.Markdown(textwrap.dedent("""
        ## About

        This extension takes a dataset as input, breaks it into chunks, and adds the result to a local/offline Chroma database.

        The database is then queried during inference time to get the excerpts that are closest to your input. The idea is to create an arbitrarily large pseudo context.

        The core methodology was developed and contributed by kaiokendev, who is working on improvements to the method in this repository: https://github.com/kaiokendev/superbig

        ## Data input

        Start by entering some data in the interface below and then clicking on "Load data".

        Each time you load some new data, the old chunks are discarded.

        ## Chat mode

        #### Instruct

        On each turn, the chunks will be compared to your current input and the most relevant matches will be appended to the input in the following format:

        ```
        Your reply should be based on the context below:

        ...
        ```

        The injection doesn't make it into the chat history. It is only used in the current generation.

        #### Regular chat

        The chunks from the external data sources are ignored, and the Chroma database is built based on the chat history instead. The most relevant past exchanges relative to the present input are added to the context string. This way, the extension acts as a long-term memory.

        ## Notebook/default modes

        Your question must be manually specified between `<|begin-user-input|>` and `<|end-user-input|>` tags, and the injection point must be specified with `<|injection-point|>`.

        The special tokens mentioned above (`<|begin-user-input|>`, `<|end-user-input|>`, and `<|injection-point|>`) are removed in the background before the text generation begins.

        Here is an example in Vicuna 1.1 format:

        ```
        A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.

        USER:

        <|begin-user-input|>
        What datasets are mentioned in the text below?
        <|end-user-input|>

        <|injection-point|>

        ASSISTANT:
        ```

        ⚠️ For best results, make sure to remove the spaces and new line characters after `ASSISTANT:`.

        *This extension is currently experimental and under development.*
        """))

    with gr.Row():
        with gr.Column(min_width=600):
            with gr.Tab("Text input"):
                data_input = gr.Textbox(lines=20, label='Input data')
                update_data = gr.Button('Load data')

            with gr.Tab("URL input"):
                url_input = gr.Textbox(lines=10, label='Input URLs', info='Enter one or more URLs separated by newline characters.')
                strong_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Strong cleanup', info='Only keeps html elements that look like long-form text.')
                threads = gr.Number(value=params['threads'], label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
                update_url = gr.Button('Load data')

            with gr.Tab("File input"):
                file_input = gr.File(label='Input file', type='binary')
                update_file = gr.Button('Load data')

            with gr.Tab("Generation settings"):
                chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
                # Time weighting is only passed to the chat-history collector (regular chat mode).
                gr.Markdown('Time weighting (optional, regular chat mode only; used to make recently added chunks more likely to appear)')
                time_weight = gr.Slider(0, 1, value=params['time_weight'], label='Time weight', info='Defines the strength of the time weighting. 0 = no time weighting.')
                chunk_count_initial = gr.Number(value=params['chunk_count_initial'], label='Initial chunk count', info='The number of closest-matching chunks retrieved for time weight reordering in chat mode. This should be >= chunk count. -1 = All chunks are retrieved. Only used if time_weight > 0.')

                update_settings = gr.Button('Apply changes')

            # Shared by all three "Load data" buttons.
            chunk_len = gr.Number(value=params['chunk_length'], label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
            chunk_sep = gr.Textbox(value=params['chunk_separator'], label='Chunk separator', info='Used to manually split chunks. Manually split chunks longer than chunk length are split again. This value is used when you click on "Load data".')
        with gr.Column():
            last_updated = gr.Markdown()

    update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
    update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
    update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
    update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)