From 0480d70bfd635cd82565b07f12f448ce316cedb9 Mon Sep 17 00:00:00 2001
From: CG
Date: Wed, 5 Jul 2023 14:48:34 -0700
Subject: [PATCH] updated superbooga with search integration and semantic
 filtering

---
 extensions/superbooga/script.py | 89 +++++++++++++++------------------
 1 file changed, 40 insertions(+), 49 deletions(-)

diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py
index f56a0897..79d90dba 100644
--- a/extensions/superbooga/script.py
+++ b/extensions/superbooga/script.py
@@ -12,6 +12,8 @@ from .download_urls import download_urls
 import requests
 import json
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
 
 
 params = {
     'chunk_count': 5,
@@ -20,6 +22,8 @@ params = {
     'chunk_length': 700,
     'chunk_separator': '',
     'strong_cleanup': False,
+    'semantic_cleanup': True,
+    'semantic_weight': 0.5,
     'threads': 4,
 }
 
@@ -60,7 +64,6 @@ def feed_file_into_collector(file, chunk_len, chunk_sep):
 
 
 def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads):
-    print("feed_url_into_collector")
     all_text = ''
     cumulative = ''
 
@@ -87,13 +90,18 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
     for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
         yield i
 
 
+def calculate_semantic_similarity(query_embedding, target_embedding):
+    # Calculate cosine similarity between the query embedding and the target embedding
+    similarity = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1))
+    return similarity[0][0]
+
+
 def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
     # Load parameters from the config file
     with open('custom_search_engine_keys.json') as key_file:
         key = json.load(key_file)
 
-    print("=== Searching google ===")
-    print("-- " + str(query))
+    model = SentenceTransformer('all-MiniLM-L6-v2')
+    query_embedding = model.encode([query])[0]
 
     # Set up API endpoint and parameters
     url = "https://www.googleapis.com/customsearch/v1"
@@ -113,12 +121,6 @@
         print("You need to provide an CSE ID, by modifying the script.py in oobabooga_windows \
 text-generation-webui.\nSkipping search")
         return query
-    # constructing the URL
-    # doc: https://developers.google.com/custom-search/v1/using_rest
-    # calculating start, (page=2) => (start=11), (page=3) => (start=21)
-    page = 1
-    start = (page - 1) * 10 + 1
-
     # Send API request
     response = requests.get(url, params=params)
 
@@ -127,49 +129,36 @@
     # get the result items
     search_items = data.get("items")
+
     # iterate over 10 results found
-    search_urls = ""
+    urls = ""
     for i, search_item in enumerate(search_items, start=1):
-        try:
-            long_description = search_item["pagemap"]["metatags"][0]["og:description"]
-        except KeyError:
-            long_description = "N/A"
-        # get the page title
-        title = search_item.get("title")
-        # page snippet
-        snippet = search_item.get("snippet")
-        # alternatively, you can get the HTML snippet (bolded keywords)
-        html_snippet = search_item.get("htmlSnippet")
-        # extract the page url
+        if semantic_cleanup:
+            # get titles and descriptions and use that to semantically weight the search result
+            # get the page title
+            title = search_item.get("title")
+            # page snippet
+            snippet = search_item.get("snippet")
+
+            target_sentence = str(title) + " " + str(snippet)
+            target_embedding = model.encode([target_sentence])[0]
+
+            similarity_score = calculate_semantic_similarity(query_embedding, target_embedding)
+
+            if similarity_score < semantic_requirement:
+                continue
+
+        # extract the page url and add it to the urls to download
         link = search_item.get("link")
-        search_urls += link + "\n"
+        urls += link + "\n"
 
-    # TODO don't clone feed_url_into_collector
-    all_text = ''
-    cumulative = ''
+    # Call the original feed_url_into_collector function instead of duplicating the code
+    result_generator = feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
 
-    urls = search_urls.strip().split('\n')
-    cumulative += f'Loading {len(urls)} URLs with {threads} threads...\n\n'
-    yield cumulative
-    for update, contents in download_urls(urls, threads=threads):
-        yield cumulative + update
-
-    cumulative += 'Processing the HTML sources...'
-    yield cumulative
-    for content in contents:
-        soup = BeautifulSoup(content, features="html.parser")
-        for script in soup(["script", "style"]):
-            script.extract()
-
-        strings = soup.stripped_strings
-        if strong_cleanup:
-            strings = [s for s in strings if re.search("[A-Za-z] ", s)]
-
-        text = '\n'.join([s.strip() for s in strings])
-        all_text += text
-
-    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
-        yield i
+    # Consume the yielded values
+    for result in result_generator:
+        yield result
 
 
 def apply_settings(chunk_count, chunk_count_initial, time_weight):
@@ -330,8 +319,10 @@ def ui():
             with gr.Tab("Search input"):
                 search_term = gr.Textbox(lines=1, label='Search Input', info='Enter a google search, returned results will be fed into the DB')
                 search_strong_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Strong cleanup', info='Only keeps html elements that look like long-form text.')
-                semantic_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Require semantic similarity (not implemented)', info='Only download pages with similar titles/snippets to the search')  # TODO cdg
-                semantic_requirement = gr.Slider(0, 1, value=params['time_weight'], label='Semantic similarity requirement (not implemented)', info='Defines the requirement of the semantic search. 0 = no culling of dissimilar pages.')  # TODO cdg
+
+                semantic_cleanup = gr.Checkbox(value=params['semantic_cleanup'], label='Require semantic similarity', info='Only download pages with similar titles/snippets to the search based on a semantic search')
+                semantic_requirement = gr.Slider(0, 1, value=params['semantic_weight'], label='Semantic similarity requirement', info='Defines the requirement of the semantic search. 0 = no culling of dissimilar pages.')
+
                 search_threads = gr.Number(value=params['threads'], label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
                 update_search = gr.Button('Load data')
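
Note on the change above: the semantic filter embeds the search query and each result's title + snippet with all-MiniLM-L6-v2 and drops results whose cosine similarity to the query falls below the "Semantic similarity requirement" slider value. Below is a minimal standalone sketch of that check; the example query, candidate strings, and 0.5 threshold are illustrative only and not part of the patch.

# Sketch of the title/snippet filter used in feed_search_into_collector (illustrative values).
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

model = SentenceTransformer('all-MiniLM-L6-v2')

query = "how to fine-tune a llama model"  # hypothetical search query
candidates = [
    "Fine-tuning LLaMA with LoRA: a practical guide",  # expected to be kept
    "Best pizza restaurants in Chicago",               # expected to be culled
]
threshold = 0.5  # plays the role of the 'Semantic similarity requirement' slider

query_embedding = model.encode([query])[0]
for text in candidates:
    target_embedding = model.encode([text])[0]
    # Same comparison as calculate_semantic_similarity() in the patch
    score = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1))[0][0]
    print(f"{score:.2f}  {'keep' if score >= threshold else 'skip'}  {text}")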