diff --git a/extensions/superbig/script.py b/extensions/superbig/script.py index 122072b2..48fc2e3c 100644 --- a/extensions/superbig/script.py +++ b/extensions/superbig/script.py @@ -65,13 +65,15 @@ class SentenceTransformerEmbedder(Embedder): embedder = SentenceTransformerEmbedder() collector = ChromaCollector(embedder) +chunk_count = 5 -def feed_data_into_collector(corpus): - global collector +def feed_data_into_collector(corpus, chunk_len, _chunk_count): + global collector, chunk_count + chunk_count = int(_chunk_count) + chunk_len = int(chunk_len) cumulative = '' - chunk_len = 700 cumulative += "Breaking the input dataset...\n\n" yield cumulative data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)] @@ -83,14 +85,14 @@ def feed_data_into_collector(corpus): yield cumulative -def feed_file_into_collector(file): +def feed_file_into_collector(file, chunk_len, chunk_count): yield 'Reading the input dataset...\n\n' text = file.decode('utf-8') - for i in feed_data_into_collector(text): + for i in feed_data_into_collector(text, chunk_len, chunk_count): yield i -def feed_url_into_collector(url): +def feed_url_into_collector(url, chunk_len, chunk_count): yield 'Loading the URL...' html = urlopen(url).read() soup = BeautifulSoup(html, features="html.parser") @@ -101,7 +103,7 @@ def feed_url_into_collector(url): lines = (line.strip() for line in text.splitlines()) chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) text = '\n\n'.join(chunk for chunk in chunks if chunk) - for i in feed_data_into_collector(text): + for i in feed_data_into_collector(text, chunk_len, chunk_count): yield i @@ -115,8 +117,8 @@ def input_modifier(string): else: user_input = '' - # Get the 5 most similar chunks - results = collector.get(user_input, n_results=5) + # Get the most similar chunks + results = collector.get(user_input, n_results=chunk_count) # Make the replacements string = string.replace('<|begin-user-input|>', '') @@ -178,9 +180,13 @@ def ui(): file_input = gr.File(label='Input file', type='binary') update_file = gr.Button('Apply') + with gr.Row(): + chunk_len = gr.Number(value=700, label='Chunk length', info='In characters, not tokens') + chunk_count = gr.Number(value=5, label='Chunk count', info='The number of closest-matching chunks to include in the prompt') + with gr.Column(): last_updated = gr.Markdown() - update_data.click(feed_data_into_collector, data_input, last_updated, show_progress=False) - update_url.click(feed_url_into_collector, url_input, last_updated, show_progress=False) - update_file.click(feed_file_into_collector, file_input, last_updated, show_progress=False) + update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_count], last_updated, show_progress=False) + update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_count], last_updated, show_progress=False) + update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_count], last_updated, show_progress=False)