diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py index 85a7411c..cc8454da 100644 --- a/extensions/superbooga/script.py +++ b/extensions/superbooga/script.py @@ -13,6 +13,7 @@ from .download_urls import download_urls params = { 'chunk_count': 5, 'chunk_length': 700, + 'chunk_separator': '', 'strong_cleanup': False, 'threads': 4, } @@ -22,17 +23,23 @@ chat_collector = make_collector() chunk_count = 5 -def feed_data_into_collector(corpus, chunk_len): +def feed_data_into_collector(corpus, chunk_len, chunk_sep): global collector # Defining variables chunk_len = int(chunk_len) + chunk_sep = chunk_sep.replace(r'\n', '\n') cumulative = '' # Breaking the data into chunks and adding those to the db cumulative += "Breaking the input dataset...\n\n" yield cumulative - data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)] + if chunk_sep: + data_chunks = corpus.split(chunk_sep) + data_chunks = [[data_chunk[i:i + chunk_len] for i in range(0, len(data_chunk), chunk_len)] for data_chunk in data_chunks] + data_chunks = [x for y in data_chunks for x in y] + else: + data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)] cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n" yield cumulative add_chunks_to_collector(data_chunks, collector) @@ -40,14 +47,14 @@ def feed_data_into_collector(corpus, chunk_len): yield cumulative -def feed_file_into_collector(file, chunk_len): +def feed_file_into_collector(file, chunk_len, chunk_sep): yield 'Reading the input dataset...\n\n' text = file.decode('utf-8') - for i in feed_data_into_collector(text, chunk_len): + for i in feed_data_into_collector(text, chunk_len, chunk_sep): yield i -def feed_url_into_collector(urls, chunk_len, strong_cleanup, threads): +def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads): all_text = '' cumulative = '' @@ -71,7 +78,7 @@ def feed_url_into_collector(urls, chunk_len, strong_cleanup, threads): text = '\n'.join([s.strip() for s in strings]) all_text += text - for i in feed_data_into_collector(all_text, chunk_len): + for i in feed_data_into_collector(all_text, chunk_len, chunk_sep): yield i @@ -232,10 +239,11 @@ def ui(): update_settings = gr.Button('Apply changes') chunk_len = gr.Number(value=params['chunk_length'], label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".') + chunk_sep = gr.Textbox(value=params['chunk_separator'], label='Chunk separator', info='Used to manually split chunks. Manually split chunks longer than chunk length are split again. This value is used when you click on "Load data".') with gr.Column(): last_updated = gr.Markdown() - update_data.click(feed_data_into_collector, [data_input, chunk_len], last_updated, show_progress=False) - update_url.click(feed_url_into_collector, [url_input, chunk_len, strong_cleanup, threads], last_updated, show_progress=False) - update_file.click(feed_file_into_collector, [file_input, chunk_len], last_updated, show_progress=False) + update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False) + update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False) + update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False) update_settings.click(apply_settings, [chunk_count], last_updated, show_progress=False)