Add new feature: Enable search engine integration in script.py

CG 2023-07-05 13:10:58 -07:00
parent b67c362735
commit ab4ca9a3dd


@@ -10,6 +10,9 @@ from modules.logging_colors import logger
from .chromadb import add_chunks_to_collector, make_collector
from .download_urls import download_urls
import requests
import json
params = {
    'chunk_count': 5,
    'chunk_count_initial': 10,
@@ -57,6 +60,7 @@ def feed_file_into_collector(file, chunk_len, chunk_sep):
def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads):
    print("feed_url_into_collector")
    all_text = ''
    cumulative = ''
@@ -83,6 +87,90 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
        yield i
def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
    # Load parameters from the config file
    with open('custom_search_engine_keys.json') as key_file:
        key = json.load(key_file)
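    # Expected contents of custom_search_engine_keys.json (placeholder values for
    # illustration; the real values come from the Google Programmable Search setup
    # described in the "Search input" tab help below):
    # {
    #     "key": "<your Custom Search API key>",
    #     "cx": "<your search engine ID>"
    # }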
print("=== Searching google ===")
print("-- " + str(query))
# Set up API endpoint and parameters
url = "https://www.googleapis.com/customsearch/v1"
# Retrieve the values from the config dictionary
params = {
"key": key.get("key", "default_key_value"),
"cx": key.get("cx", "default_custom_engine_value"),
"q": str(query),
}
if "default_key_value" in str(params):
print("You need to provide an API key, by modifying the custom_search_engine_keys.json in oobabooga_windows \ text-generation-webui.\nSkipping search")
return query
if "default_custom_engine_value" in str(params):
print("You need to provide an CSE ID, by modifying the script.py in oobabooga_windows \ text-generation-webui.\nSkipping search")
return query
    # constructing the URL
    # doc: https://developers.google.com/custom-search/v1/using_rest
    # calculating start, (page=2) => (start=11), (page=3) => (start=21)
    # (start is computed but not passed to the request, so only the first page is fetched)
    page = 1
    start = (page - 1) * 10 + 1
    # Send API request
    response = requests.get(url, params=params)
    # Parse JSON response
    data = response.json()
    # get the result items
    search_items = data.get("items")
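    # Note: the Custom Search JSON API omits "items" when a query returns no results,
    # so search_items can be None here; the loop below assumes a non-empty list.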
    # iterate over 10 results found
    search_urls = ""
    for i, search_item in enumerate(search_items, start=1):
        try:
            long_description = search_item["pagemap"]["metatags"][0]["og:description"]
        except KeyError:
            long_description = "N/A"
        # get the page title
        title = search_item.get("title")
        # page snippet
        snippet = search_item.get("snippet")
        # alternatively, you can get the HTML snippet (bolded keywords)
        html_snippet = search_item.get("htmlSnippet")
        # extract the page url
        link = search_item.get("link")
        search_urls += link + "\n"
    # TODO don't clone feed_url_into_collector
    all_text = ''
    cumulative = ''
    urls = search_urls.strip().split('\n')
    cumulative += f'Loading {len(urls)} URLs with {threads} threads...\n\n'
    yield cumulative
    for update, contents in download_urls(urls, threads=threads):
        yield cumulative + update
    cumulative += 'Processing the HTML sources...'
    yield cumulative
    for content in contents:
        soup = BeautifulSoup(content, features="html.parser")
        for script in soup(["script", "style"]):
            script.extract()
        strings = soup.stripped_strings
        if strong_cleanup:
            strings = [s for s in strings if re.search("[A-Za-z] ", s)]
        text = '\n'.join([s.strip() for s in strings])
        all_text += text
    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
        yield i
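For reference, here is a minimal standalone sketch of the same Custom Search request/parse flow, useful for verifying the key file outside the web UI. It assumes custom_search_engine_keys.json sits in the working directory; the helper name check_search_keys is illustrative and not part of this commit.

import json
import requests

def check_search_keys(query="hello world", path="custom_search_engine_keys.json"):
    # Load the same key file the extension reads
    with open(path) as f:
        key = json.load(f)
    # Same endpoint and parameters as feed_search_into_collector
    response = requests.get(
        "https://www.googleapis.com/customsearch/v1",
        params={"key": key["key"], "cx": key["cx"], "q": query},
        timeout=10,
    )
    response.raise_for_status()
    items = response.json().get("items", [])
    # Print the links that would be fed into the collector
    for item in items:
        print(item.get("link"))

if __name__ == "__main__":
    check_search_keys()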
def apply_settings(chunk_count, chunk_count_initial, time_weight):
    global params
@@ -96,39 +184,38 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight):
def custom_generate_chat_prompt(user_input, state, **kwargs):
    global chat_collector
    history = state['history']
    if state['mode'] == 'instruct':
        results = collector.get_sorted(user_input, n_results=params['chunk_count'])
        additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
        user_input += additional_context
        logger.info(f'\n\n=== === ===\nAdding the following new context:\n{additional_context}\n=== === ===\n')
    else:
        def make_single_exchange(id_):
            output = ''
            output += f"{state['name1']}: {history['internal'][id_][0]}\n"
            output += f"{state['name2']}: {history['internal'][id_][1]}\n"
            return output
        if len(history['internal']) > params['chunk_count'] and user_input != '':
            chunks = []
            hist_size = len(history['internal'])
            for i in range(hist_size-1):
                chunks.append(make_single_exchange(i))
            add_chunks_to_collector(chunks, chat_collector)
            query = '\n'.join(history['internal'][-1] + [user_input])
            try:
                best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
                additional_context = '\n'
                for id_ in best_ids:
                    if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
                        additional_context += make_single_exchange(id_)
                logger.warning(f'Adding the following new context:\n{additional_context}')
                state['context'] = state['context'].strip() + '\n' + additional_context
                kwargs['history'] = {
                    'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids],
                    'visible': ''
                }
            except RuntimeError:
@@ -240,6 +327,36 @@ def ui():
                file_input = gr.File(label='Input file', type='binary')
                update_file = gr.Button('Load data')
with gr.Tab("Search input"):
search_term = gr.Textbox(lines=1, label='Search Input', info='Enter a google search, returned results will be fed into the DB')
search_strong_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Strong cleanup', info='Only keeps html elements that look like long-form text.')
semantic_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Require semantic similarity (not implemented)', info='Only download pages with similar titles/snippets to the search') # TODO cdg
semantic_requirement = gr.Slider(0, 1, value=params['time_weight'], label='Semantic similarity requirement (not implemented)', info='Defines the requirement of the semantic search. 0 = no culling of dissimilar pages.') # TODO cdg
search_threads = gr.Number(value=params['threads'], label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
update_search = gr.Button('Load data')
with gr.Accordion("Click for more information...", open=False):
gr.Markdown(textwrap.dedent("""
# installation/setup
Please follow the instruction found here to setup a custom search engine with Google.
https://www.thepythoncode.com/article/use-google-custom-search-engine-api-in-python
create a file called "custom_search_engine_keys.json"
Paste this text in it and replace with your values from the previous step:
"
{
"key": "Custom search engine key",
"cx": "Custom search engine cx number"
}
"
# usage
Enter a search query above. Press the load data button. This data will be added to the local chromaDB to be read into context at runtime.
"""))
with gr.Tab("Generation settings"): with gr.Tab("Generation settings"):
chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.') chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)') gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)')
@@ -256,4 +373,5 @@ def ui():
    update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
    update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
    update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
    update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated, show_progress=False)
    update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
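For reference, the new pipeline can also be exercised without launching the UI by driving the generator directly; the import path and argument values below are illustrative assumptions rather than part of this commit.

# Illustrative only: run the search-to-collector pipeline from a Python shell.
# Assumes the webui is installed and the superbooga extension is importable.
from extensions.superbooga.script import feed_search_into_collector

for status in feed_search_into_collector(
        query="retrieval augmented generation",
        chunk_len=700,
        chunk_sep='',
        strong_cleanup=False,
        semantic_cleanup=False,       # accepted but not yet implemented
        semantic_requirement=0.5,     # accepted but not yet implemented
        threads=4):
    print(status)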