From ab4ca9a3dd29a9ac8315de3eef67e7c2d27df046 Mon Sep 17 00:00:00 2001 From: CG Date: Wed, 5 Jul 2023 13:10:58 -0700 Subject: [PATCH] Add new feature: Enable search engine integration in script.py --- extensions/superbooga/script.py | 136 +++++++++++++++++++++++++++++--- 1 file changed, 127 insertions(+), 9 deletions(-) diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py index c0d3f8eb..f56a0897 100644 --- a/extensions/superbooga/script.py +++ b/extensions/superbooga/script.py @@ -10,6 +10,9 @@ from modules.logging_colors import logger from .chromadb import add_chunks_to_collector, make_collector from .download_urls import download_urls +import requests +import json + params = { 'chunk_count': 5, 'chunk_count_initial': 10, @@ -57,6 +60,7 @@ def feed_file_into_collector(file, chunk_len, chunk_sep): def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads): + print("feed_url_into_collector") all_text = '' cumulative = '' @@ -83,6 +87,90 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads) for i in feed_data_into_collector(all_text, chunk_len, chunk_sep): yield i +def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads): + # Load parameters from the config file + with open('custom_search_engine_keys.json') as key_file: + key = json.load(key_file) + + print("=== Searching google ===") + print("-- " + str(query)) + + # Set up API endpoint and parameters + url = "https://www.googleapis.com/customsearch/v1" + + # Retrieve the values from the config dictionary + params = { + "key": key.get("key", "default_key_value"), + "cx": key.get("cx", "default_custom_engine_value"), + "q": str(query), + } + + if "default_key_value" in str(params): + print("You need to provide an API key, by modifying the custom_search_engine_keys.json in oobabooga_windows \ text-generation-webui.\nSkipping search") + return query + + if "default_custom_engine_value" in str(params): + print("You need to provide an CSE ID, by modifying the script.py in oobabooga_windows \ text-generation-webui.\nSkipping search") + return query + + # constructing the URL + # doc: https://developers.google.com/custom-search/v1/using_rest + # calculating start, (page=2) => (start=11), (page=3) => (start=21) + page = 1 + start = (page - 1) * 10 + 1 + + # Send API request + response = requests.get(url, params=params) + + # Parse JSON response + data = response.json() + + # get the result items + search_items = data.get("items") + # iterate over 10 results found + search_urls = "" + for i, search_item in enumerate(search_items, start=1): + try: + long_description = search_item["pagemap"]["metatags"][0]["og:description"] + except KeyError: + long_description = "N/A" + # get the page title + title = search_item.get("title") + # page snippet + snippet = search_item.get("snippet") + # alternatively, you can get the HTML snippet (bolded keywords) + html_snippet = search_item.get("htmlSnippet") + # extract the page url + link = search_item.get("link") + search_urls += link + "\n" + + # TODO don't clone feed_url_into_collector + all_text = '' + cumulative = '' + + urls = search_urls.strip().split('\n') + cumulative += f'Loading {len(urls)} URLs with {threads} threads...\n\n' + yield cumulative + for update, contents in download_urls(urls, threads=threads): + yield cumulative + update + + cumulative += 'Processing the HTML sources...' + yield cumulative + for content in contents: + soup = BeautifulSoup(content, features="html.parser") + for script in soup(["script", "style"]): + script.extract() + + strings = soup.stripped_strings + if strong_cleanup: + strings = [s for s in strings if re.search("[A-Za-z] ", s)] + + text = '\n'.join([s.strip() for s in strings]) + all_text += text + + for i in feed_data_into_collector(all_text, chunk_len, chunk_sep): + yield i + def apply_settings(chunk_count, chunk_count_initial, time_weight): global params @@ -96,39 +184,38 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight): def custom_generate_chat_prompt(user_input, state, **kwargs): global chat_collector - history = state['history'] - if state['mode'] == 'instruct': results = collector.get_sorted(user_input, n_results=params['chunk_count']) additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results) user_input += additional_context + logger.info(f'\n\n=== === ===\nAdding the following new context:\n{additional_context}\n=== === ===\n') else: def make_single_exchange(id_): output = '' - output += f"{state['name1']}: {history['internal'][id_][0]}\n" - output += f"{state['name2']}: {history['internal'][id_][1]}\n" + output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n" + output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n" return output - if len(history['internal']) > params['chunk_count'] and user_input != '': + if len(shared.history['internal']) > params['chunk_count'] and user_input != '': chunks = [] - hist_size = len(history['internal']) + hist_size = len(shared.history['internal']) for i in range(hist_size-1): chunks.append(make_single_exchange(i)) add_chunks_to_collector(chunks, chat_collector) - query = '\n'.join(history['internal'][-1] + [user_input]) + query = '\n'.join(shared.history['internal'][-1] + [user_input]) try: best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight']) additional_context = '\n' for id_ in best_ids: - if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>': + if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>': additional_context += make_single_exchange(id_) logger.warning(f'Adding the following new context:\n{additional_context}') state['context'] = state['context'].strip() + '\n' + additional_context kwargs['history'] = { - 'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids], + 'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids], 'visible': '' } except RuntimeError: @@ -240,6 +327,36 @@ def ui(): file_input = gr.File(label='Input file', type='binary') update_file = gr.Button('Load data') + with gr.Tab("Search input"): + search_term = gr.Textbox(lines=1, label='Search Input', info='Enter a google search, returned results will be fed into the DB') + search_strong_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Strong cleanup', info='Only keeps html elements that look like long-form text.') + semantic_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Require semantic similarity (not implemented)', info='Only download pages with similar titles/snippets to the search') # TODO cdg + semantic_requirement = gr.Slider(0, 1, value=params['time_weight'], label='Semantic similarity requirement (not implemented)', info='Defines the requirement of the semantic search. 0 = no culling of dissimilar pages.') # TODO cdg + search_threads = gr.Number(value=params['threads'], label='Threads', info='The number of threads to use while downloading the URLs.', precision=0) + update_search = gr.Button('Load data') + + with gr.Accordion("Click for more information...", open=False): + gr.Markdown(textwrap.dedent(""" + + # installation/setup + Please follow the instruction found here to setup a custom search engine with Google. + https://www.thepythoncode.com/article/use-google-custom-search-engine-api-in-python + + create a file called "custom_search_engine_keys.json" + + Paste this text in it and replace with your values from the previous step: + " + { + "key": "Custom search engine key", + "cx": "Custom search engine cx number" + } + " + + # usage + Enter a search query above. Press the load data button. This data will be added to the local chromaDB to be read into context at runtime. + + """)) + with gr.Tab("Generation settings"): chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.') gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)') @@ -256,4 +373,5 @@ def ui(): update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False) update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False) update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False) + update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated,show_progress=False) update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)