Add new feature: Enable search engine integration in script.py

CG 2023-07-05 13:10:58 -07:00
parent b67c362735
commit ab4ca9a3dd


@@ -10,6 +10,9 @@ from modules.logging_colors import logger
from .chromadb import add_chunks_to_collector, make_collector
from .download_urls import download_urls
import requests
import json

params = {
    'chunk_count': 5,
    'chunk_count_initial': 10,
@@ -57,6 +60,7 @@ def feed_file_into_collector(file, chunk_len, chunk_sep):

def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads):
    print("feed_url_into_collector")
    all_text = ''
    cumulative = ''
@@ -83,6 +87,90 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
        yield i

def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
    # Load the API credentials from the config file
    with open('custom_search_engine_keys.json') as key_file:
        key = json.load(key_file)

    print("=== Searching google ===")
    print("-- " + str(query))

    # Set up the API endpoint and parameters
    url = "https://www.googleapis.com/customsearch/v1"

    # Retrieve the values from the config dictionary
    params = {
        "key": key.get("key", "default_key_value"),
        "cx": key.get("cx", "default_custom_engine_value"),
        "q": str(query),
    }

    if "default_key_value" in str(params):
        print("You need to provide an API key by modifying custom_search_engine_keys.json in oobabooga_windows\\text-generation-webui.\nSkipping search")
        return query

    if "default_custom_engine_value" in str(params):
        print("You need to provide a CSE ID by modifying custom_search_engine_keys.json in oobabooga_windows\\text-generation-webui.\nSkipping search")
        return query
    # constructing the request
    # doc: https://developers.google.com/custom-search/v1/using_rest
    # calculating start: (page=2) => (start=11), (page=3) => (start=21)
    # only the first page is requested for now, so start is computed but not sent
    page = 1
    start = (page - 1) * 10 + 1

    # Send the API request
    response = requests.get(url, params=params)

    # Parse the JSON response
    data = response.json()

    # Get the result items (an empty list if the search returned nothing)
    search_items = data.get("items", [])

    # Iterate over the results and collect the page URLs
    search_urls = ""
    for i, search_item in enumerate(search_items, start=1):
        try:
            long_description = search_item["pagemap"]["metatags"][0]["og:description"]
        except KeyError:
            long_description = "N/A"
        # get the page title
        title = search_item.get("title")
        # page snippet
        snippet = search_item.get("snippet")
        # alternatively, you can get the HTML snippet (bolded keywords)
        html_snippet = search_item.get("htmlSnippet")
        # title, snippet, and long_description are collected but not used yet
        # (see the "semantic similarity" TODO in the UI below)
        # extract the page url
        link = search_item.get("link")
        search_urls += link + "\n"
    # TODO don't clone feed_url_into_collector
    all_text = ''
    cumulative = ''

    urls = search_urls.strip().split('\n')
    cumulative += f'Loading {len(urls)} URLs with {threads} threads...\n\n'
    yield cumulative

    for update, contents in download_urls(urls, threads=threads):
        yield cumulative + update

    cumulative += 'Processing the HTML sources...'
    yield cumulative
    for content in contents:
        soup = BeautifulSoup(content, features="html.parser")
        for script in soup(["script", "style"]):
            script.extract()

        strings = soup.stripped_strings
        if strong_cleanup:
            strings = [s for s in strings if re.search("[A-Za-z] ", s)]

        text = '\n'.join([s.strip() for s in strings])
        all_text += text

    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
        yield i


def apply_settings(chunk_count, chunk_count_initial, time_weight):
    global params
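For reference, the Custom Search call made by feed_search_into_collector can be exercised on its own. The sketch below is a minimal standalone version of that request, assuming the same custom_search_engine_keys.json layout described in the UI accordion further down; the helper name google_cse_search and the explicit "start" parameter (which the function above computes but never sends) are illustrative additions, not part of the commit.

import json

import requests


def google_cse_search(query, page=1):
    """Return result links for one page of a Google Programmable Search query."""
    # Credentials are read from the same custom_search_engine_keys.json used above.
    with open('custom_search_engine_keys.json') as key_file:
        key = json.load(key_file)

    params = {
        "key": key["key"],             # API key
        "cx": key["cx"],               # custom search engine ID
        "q": query,                    # search query
        "start": (page - 1) * 10 + 1,  # page=2 => start=11, page=3 => start=21
    }
    response = requests.get("https://www.googleapis.com/customsearch/v1", params=params)
    response.raise_for_status()
    data = response.json()

    # Each result item exposes "title", "snippet", and "link"; only the links are needed here.
    return [item["link"] for item in data.get("items", [])]

For example, google_cse_search('gradio blocks tutorial', page=1) returns up to ten result URLs, which is the same list that feed_search_into_collector accumulates in search_urls.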
@@ -96,39 +184,38 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight):

def custom_generate_chat_prompt(user_input, state, **kwargs):
    global chat_collector

    history = state['history']
    if state['mode'] == 'instruct':
        results = collector.get_sorted(user_input, n_results=params['chunk_count'])
        additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
        user_input += additional_context
        logger.info(f'\n\n=== === ===\nAdding the following new context:\n{additional_context}\n=== === ===\n')
    else:

        def make_single_exchange(id_):
            output = ''
            output += f"{state['name1']}: {history['internal'][id_][0]}\n"
            output += f"{state['name2']}: {history['internal'][id_][1]}\n"
            output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n"
            output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
            return output

        if len(history['internal']) > params['chunk_count'] and user_input != '':
        if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
            chunks = []
            hist_size = len(history['internal'])
            hist_size = len(shared.history['internal'])
            for i in range(hist_size-1):
                chunks.append(make_single_exchange(i))

            add_chunks_to_collector(chunks, chat_collector)
            query = '\n'.join(history['internal'][-1] + [user_input])
            query = '\n'.join(shared.history['internal'][-1] + [user_input])
            try:
                best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
                additional_context = '\n'
                for id_ in best_ids:
                    if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
                    if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
                        additional_context += make_single_exchange(id_)

                logger.warning(f'Adding the following new context:\n{additional_context}')
                state['context'] = state['context'].strip() + '\n' + additional_context
                kwargs['history'] = {
                    'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids],
                    'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
                    'visible': ''
                }
            except RuntimeError:
@@ -240,6 +327,36 @@ def ui():
                file_input = gr.File(label='Input file', type='binary')
                update_file = gr.Button('Load data')

            with gr.Tab("Search input"):
                search_term = gr.Textbox(lines=1, label='Search input', info='Enter a Google search; the returned results will be fed into the DB.')
                search_strong_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Strong cleanup', info='Only keeps html elements that look like long-form text.')
                semantic_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Require semantic similarity (not implemented)', info='Only download pages with titles/snippets similar to the search.')  # TODO cdg
                semantic_requirement = gr.Slider(0, 1, value=params['time_weight'], label='Semantic similarity requirement (not implemented)', info='Defines the strictness of the semantic filter. 0 = no culling of dissimilar pages.')  # TODO cdg
                search_threads = gr.Number(value=params['threads'], label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
                update_search = gr.Button('Load data')
                with gr.Accordion("Click for more information...", open=False):
                    gr.Markdown(textwrap.dedent("""
                    # Installation/setup
                    Please follow the instructions found here to set up a custom search engine with Google:
                    https://www.thepythoncode.com/article/use-google-custom-search-engine-api-in-python

                    Create a file called "custom_search_engine_keys.json" and paste this text into it, replacing the values with your own from the previous step:

                        {
                            "key": "Custom search engine key",
                            "cx": "Custom search engine cx number"
                        }

                    # Usage
                    Enter a search query above and press the Load data button. The results will be added to the local ChromaDB and read into the context at runtime.
                    """))
with gr.Tab("Generation settings"):
chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)')
@@ -256,4 +373,5 @@ def ui():
    update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
    update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
    update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
    update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated, show_progress=False)
    update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
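As a convenience for the setup step documented in the accordion above, the expected key file can also be written with a few lines of Python. This is a minimal sketch: the values are placeholders and must be replaced with real credentials from the Google Programmable Search Engine console, and the file belongs in the text-generation-webui folder (the working directory when the webui is launched).

import json

# Create custom_search_engine_keys.json with the two keys the extension reads.
# The values below are placeholders, not working credentials.
keys = {
    "key": "Custom search engine key",       # replace with your API key
    "cx": "Custom search engine cx number",  # replace with your CSE ID
}

with open('custom_search_engine_keys.json', 'w') as f:
    json.dump(keys, f, indent=4)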