Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-26 09:40:20 +01:00)
Add new feature: Enable search engine integration in script.py
This commit is contained in:
parent b67c362735
commit ab4ca9a3dd
@@ -10,6 +10,9 @@ from modules.logging_colors import logger
from .chromadb import add_chunks_to_collector, make_collector
from .download_urls import download_urls

import requests
import json

params = {
    'chunk_count': 5,
    'chunk_count_initial': 10,
@@ -57,6 +60,7 @@ def feed_file_into_collector(file, chunk_len, chunk_sep):


def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads):
    print("feed_url_into_collector")
    all_text = ''
    cumulative = ''

@@ -83,6 +87,90 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
        yield i

def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
    # Load parameters from the config file
    with open('custom_search_engine_keys.json') as key_file:
        key = json.load(key_file)

    print("=== Searching Google ===")
    print("-- " + str(query))

    # Set up API endpoint and parameters
    url = "https://www.googleapis.com/customsearch/v1"

    # Retrieve the values from the config dictionary
    params = {
        "key": key.get("key", "default_key_value"),
        "cx": key.get("cx", "default_custom_engine_value"),
        "q": str(query),
    }

    if "default_key_value" in str(params):
        print("You need to provide an API key by modifying custom_search_engine_keys.json in oobabooga_windows\\text-generation-webui.\nSkipping search")
        return query

    if "default_custom_engine_value" in str(params):
        print("You need to provide a CSE ID by modifying custom_search_engine_keys.json in oobabooga_windows\\text-generation-webui.\nSkipping search")
        return query

    # constructing the URL
    # doc: https://developers.google.com/custom-search/v1/using_rest
    # calculating start, (page=2) => (start=11), (page=3) => (start=21)
    page = 1
    start = (page - 1) * 10 + 1

    # Send API request
    response = requests.get(url, params=params)

    # Parse JSON response
    data = response.json()

    # get the result items
    search_items = data.get("items")
    # iterate over 10 results found
    search_urls = ""
    for i, search_item in enumerate(search_items, start=1):
        try:
            long_description = search_item["pagemap"]["metatags"][0]["og:description"]
        except KeyError:
            long_description = "N/A"
        # get the page title
        title = search_item.get("title")
        # page snippet
        snippet = search_item.get("snippet")
        # alternatively, you can get the HTML snippet (bolded keywords)
        html_snippet = search_item.get("htmlSnippet")
        # extract the page url
        link = search_item.get("link")
        search_urls += link + "\n"

    # TODO don't clone feed_url_into_collector
    all_text = ''
    cumulative = ''

    urls = search_urls.strip().split('\n')
    cumulative += f'Loading {len(urls)} URLs with {threads} threads...\n\n'
    yield cumulative
    for update, contents in download_urls(urls, threads=threads):
        yield cumulative + update

    cumulative += 'Processing the HTML sources...'
    yield cumulative
    for content in contents:
        soup = BeautifulSoup(content, features="html.parser")
        for script in soup(["script", "style"]):
            script.extract()

        strings = soup.stripped_strings
        if strong_cleanup:
            strings = [s for s in strings if re.search("[A-Za-z] ", s)]

        text = '\n'.join([s.strip() for s in strings])
        all_text += text

    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
        yield i
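
Aside (not part of the commit): the core of feed_search_into_collector above is a single request to the Google Custom Search JSON API. A minimal standalone sketch of that call, assuming a valid custom_search_engine_keys.json as described in the setup notes further down:

# Minimal sketch of the Custom Search request used above; assumes the
# custom_search_engine_keys.json file described in the setup notes exists.
import json
import requests

with open('custom_search_engine_keys.json') as key_file:
    key = json.load(key_file)

response = requests.get(
    "https://www.googleapis.com/customsearch/v1",
    params={"key": key["key"], "cx": key["cx"], "q": "example query"},
)
response.raise_for_status()

# Each result item exposes "title", "snippet", "htmlSnippet" and "link",
# plus "pagemap" metadata that the function above probes for "og:description".
for item in response.json().get("items", []):
    print(item.get("title"), "->", item.get("link"))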


def apply_settings(chunk_count, chunk_count_initial, time_weight):
    global params
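
Aside (not part of the commit): the "# TODO don't clone feed_url_into_collector" comment above flags that the HTML-cleanup loop is duplicated from feed_url_into_collector. A hypothetical shared helper along these lines (name and signature are illustrative only) could be called from both functions:

# Hypothetical refactor of the duplicated cleanup loop; not part of this commit.
import re

from bs4 import BeautifulSoup


def html_to_text(content, strong_cleanup=False):
    soup = BeautifulSoup(content, features="html.parser")
    for script in soup(["script", "style"]):
        script.extract()  # drop non-visible elements

    strings = soup.stripped_strings
    if strong_cleanup:
        # keep only strings that look like running text
        strings = [s for s in strings if re.search("[A-Za-z] ", s)]

    return '\n'.join(s.strip() for s in strings)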
@@ -96,39 +184,38 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight):
def custom_generate_chat_prompt(user_input, state, **kwargs):
    global chat_collector

    history = state['history']

    if state['mode'] == 'instruct':
        results = collector.get_sorted(user_input, n_results=params['chunk_count'])
        additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
        user_input += additional_context
        logger.info(f'\n\n=== === ===\nAdding the following new context:\n{additional_context}\n=== === ===\n')
    else:

        def make_single_exchange(id_):
            output = ''
            output += f"{state['name1']}: {history['internal'][id_][0]}\n"
            output += f"{state['name2']}: {history['internal'][id_][1]}\n"
            output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n"
            output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
            return output

        if len(history['internal']) > params['chunk_count'] and user_input != '':
        if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
            chunks = []
            hist_size = len(history['internal'])
            hist_size = len(shared.history['internal'])
            for i in range(hist_size-1):
                chunks.append(make_single_exchange(i))

            add_chunks_to_collector(chunks, chat_collector)
            query = '\n'.join(history['internal'][-1] + [user_input])
            query = '\n'.join(shared.history['internal'][-1] + [user_input])
            try:
                best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
                additional_context = '\n'
                for id_ in best_ids:
                    if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
                    if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
                        additional_context += make_single_exchange(id_)

                logger.warning(f'Adding the following new context:\n{additional_context}')
                state['context'] = state['context'].strip() + '\n' + additional_context
                kwargs['history'] = {
                    'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids],
                    'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
                    'visible': ''
                }
            except RuntimeError:
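
Aside (not part of the commit): the time_weight passed to chat_collector.get_ids_sorted() above is handled in the extension's chromadb module, which this diff does not touch. Purely to illustrate the idea, a time weight can blend recency into a similarity ranking roughly like this:

# Illustration only: one way a time weight could mix recency into ranking.
# The extension's actual implementation is not shown in this diff.
def rank_ids(distances, ages, time_weight=0.5):
    # distances: smaller = more similar; ages: 0.0 = newest chunk, 1.0 = oldest
    scores = [(1 - time_weight) * d + time_weight * a for d, a in zip(distances, ages)]
    return sorted(range(len(scores)), key=lambda i: scores[i])


print(rank_ids([0.2, 0.1, 0.4], [1.0, 0.5, 0.0]))  # -> [2, 1, 0]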
@@ -240,6 +327,36 @@ def ui():
        file_input = gr.File(label='Input file', type='binary')
        update_file = gr.Button('Load data')

    with gr.Tab("Search input"):
        search_term = gr.Textbox(lines=1, label='Search Input', info='Enter a Google search; the returned results will be fed into the DB')
        search_strong_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Strong cleanup', info='Only keeps HTML elements that look like long-form text.')
        semantic_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Require semantic similarity (not implemented)', info='Only download pages with titles/snippets similar to the search.')  # TODO cdg
        semantic_requirement = gr.Slider(0, 1, value=params['time_weight'], label='Semantic similarity requirement (not implemented)', info='How strict the semantic filter is. 0 = no culling of dissimilar pages.')  # TODO cdg
        search_threads = gr.Number(value=params['threads'], label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
        update_search = gr.Button('Load data')

        with gr.Accordion("Click for more information...", open=False):
            gr.Markdown(textwrap.dedent("""

            # Installation/setup
            Please follow the instructions found here to set up a custom search engine with Google:
            https://www.thepythoncode.com/article/use-google-custom-search-engine-api-in-python

            Create a file called "custom_search_engine_keys.json" in the text-generation-webui folder.

            Paste the following into it and replace the placeholder values with the key and cx from the previous step:
            "
            {
                "key": "Custom search engine key",
                "cx": "Custom search engine cx number"
            }
            "

            # Usage
            Enter a search query above and press the Load data button. The returned pages will be added to the local ChromaDB and read into the context at runtime.

            """))

    with gr.Tab("Generation settings"):
        chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
        gr.Markdown('Time weighting (optional, used to make recently added chunks more likely to appear)')
@@ -256,4 +373,5 @@ def ui():
    update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
    update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
    update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
    update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated, show_progress=False)
    update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
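
Aside (not part of the commit): each handler above passes a generator function to Button.click, which is why the feed_* functions yield status strings instead of returning them. A self-contained sketch of that streaming pattern (the component and function names here are made up for the example):

# Standalone sketch of the yield-based status pattern used by the feed_* handlers.
import gradio as gr


def slow_feed(query):
    yield f'Searching for "{query}"...'
    # ... download, clean and embed the results here ...
    yield 'Done.'


with gr.Blocks() as demo:
    query_box = gr.Textbox(label='Search Input')
    status = gr.Markdown()
    load_btn = gr.Button('Load data')
    load_btn.click(slow_feed, [query_box], status, show_progress=False)

# demo.launch()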