Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-29 19:09:32 +01:00)
Add new feature: Enable search engine integration in script.py

This commit is contained in:
parent b67c362735
commit ab4ca9a3dd
@@ -10,6 +10,9 @@ from modules.logging_colors import logger
 from .chromadb import add_chunks_to_collector, make_collector
 from .download_urls import download_urls
 
+import requests
+import json
+
 params = {
     'chunk_count': 5,
     'chunk_count_initial': 10,
@@ -57,6 +60,7 @@ def feed_file_into_collector(file, chunk_len, chunk_sep):
 
 
 def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads):
+    print("feed_url_into_collector")
     all_text = ''
     cumulative = ''
 
@@ -83,6 +87,90 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
     for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
         yield i
 
+
+def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
+    # Load parameters from the config file
+    with open('custom_search_engine_keys.json') as key_file:
+        key = json.load(key_file)
+
+    print("=== Searching google ===")
+    print("-- " + str(query))
+
+    # Set up API endpoint and parameters
+    url = "https://www.googleapis.com/customsearch/v1"
+
+    # Retrieve the values from the config dictionary
+    params = {
+        "key": key.get("key", "default_key_value"),
+        "cx": key.get("cx", "default_custom_engine_value"),
+        "q": str(query),
+    }
+
+    if "default_key_value" in str(params):
+        print("You need to provide an API key by modifying custom_search_engine_keys.json in oobabooga_windows\\text-generation-webui.\nSkipping search")
+        return query
+
+    if "default_custom_engine_value" in str(params):
+        print("You need to provide a CSE ID by modifying custom_search_engine_keys.json in oobabooga_windows\\text-generation-webui.\nSkipping search")
+        return query
+
+    # constructing the URL
+    # doc: https://developers.google.com/custom-search/v1/using_rest
+    # calculating start, (page=2) => (start=11), (page=3) => (start=21)
+    page = 1
+    start = (page - 1) * 10 + 1
+
+    # Send API request
+    response = requests.get(url, params=params)
+
+    # Parse JSON response
+    data = response.json()
+
+    # get the result items
+    search_items = data.get("items")
+    # iterate over 10 results found
+    search_urls = ""
+    for i, search_item in enumerate(search_items, start=1):
+        try:
+            long_description = search_item["pagemap"]["metatags"][0]["og:description"]
+        except KeyError:
+            long_description = "N/A"
+        # get the page title
+        title = search_item.get("title")
+        # page snippet
+        snippet = search_item.get("snippet")
+        # alternatively, you can get the HTML snippet (bolded keywords)
+        html_snippet = search_item.get("htmlSnippet")
+        # extract the page url
+        link = search_item.get("link")
+        search_urls += link + "\n"
+
+    # TODO don't clone feed_url_into_collector
+    all_text = ''
+    cumulative = ''
+
+    urls = search_urls.strip().split('\n')
+    cumulative += f'Loading {len(urls)} URLs with {threads} threads...\n\n'
+    yield cumulative
+    for update, contents in download_urls(urls, threads=threads):
+        yield cumulative + update
+
+    cumulative += 'Processing the HTML sources...'
+    yield cumulative
+    for content in contents:
+        soup = BeautifulSoup(content, features="html.parser")
+        for script in soup(["script", "style"]):
+            script.extract()
+
+        strings = soup.stripped_strings
+        if strong_cleanup:
+            strings = [s for s in strings if re.search("[A-Za-z] ", s)]
+
+        text = '\n'.join([s.strip() for s in strings])
+        all_text += text
+
+    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
+        yield i
+
 
 def apply_settings(chunk_count, chunk_count_initial, time_weight):
     global params
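
For reference, the link-extraction loop in feed_search_into_collector assumes the Custom Search JSON API response carries matches under an "items" list, each with fields such as "title", "snippet", "htmlSnippet", "link" and optional Open Graph metadata under "pagemap". A minimal sketch of that structure, using made-up sample values rather than real API output:

# Illustrative sample only: the shape mirrors what data.get("items") is expected
# to contain in feed_search_into_collector; all values below are invented.
sample_data = {
    "items": [
        {
            "title": "Example result",
            "snippet": "A short plain-text snippet...",
            "htmlSnippet": "A short <b>plain-text</b> snippet...",
            "link": "https://example.com/page",
            "pagemap": {"metatags": [{"og:description": "A longer description"}]},
        }
    ]
}

search_urls = ""
for search_item in sample_data.get("items", []):
    # Same extraction as in the commit: only the page URL is kept for downloading
    search_urls += search_item.get("link") + "\n"

print(search_urls.strip())  # -> https://example.com/page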
@@ -96,39 +184,38 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight):
 def custom_generate_chat_prompt(user_input, state, **kwargs):
     global chat_collector
 
-    history = state['history']
-
     if state['mode'] == 'instruct':
         results = collector.get_sorted(user_input, n_results=params['chunk_count'])
         additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
         user_input += additional_context
+        logger.info(f'\n\n=== === ===\nAdding the following new context:\n{additional_context}\n=== === ===\n')
     else:
 
         def make_single_exchange(id_):
             output = ''
-            output += f"{state['name1']}: {history['internal'][id_][0]}\n"
+            output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n"
-            output += f"{state['name2']}: {history['internal'][id_][1]}\n"
+            output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
             return output
 
-        if len(history['internal']) > params['chunk_count'] and user_input != '':
+        if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
             chunks = []
-            hist_size = len(history['internal'])
+            hist_size = len(shared.history['internal'])
             for i in range(hist_size-1):
                 chunks.append(make_single_exchange(i))
 
             add_chunks_to_collector(chunks, chat_collector)
-            query = '\n'.join(history['internal'][-1] + [user_input])
+            query = '\n'.join(shared.history['internal'][-1] + [user_input])
             try:
                 best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
                 additional_context = '\n'
                 for id_ in best_ids:
-                    if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
+                    if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
                         additional_context += make_single_exchange(id_)
 
                 logger.warning(f'Adding the following new context:\n{additional_context}')
                 state['context'] = state['context'].strip() + '\n' + additional_context
                 kwargs['history'] = {
-                    'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids],
+                    'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
                     'visible': ''
                 }
             except RuntimeError:
@@ -240,6 +327,36 @@ def ui():
                 file_input = gr.File(label='Input file', type='binary')
                 update_file = gr.Button('Load data')
 
+            with gr.Tab("Search input"):
+                search_term = gr.Textbox(lines=1, label='Search Input', info='Enter a Google search; the returned results will be fed into the DB')
+                search_strong_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Strong cleanup', info='Only keeps html elements that look like long-form text.')
+                semantic_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Require semantic similarity (not implemented)', info='Only download pages with similar titles/snippets to the search')  # TODO cdg
+                semantic_requirement = gr.Slider(0, 1, value=params['time_weight'], label='Semantic similarity requirement (not implemented)', info='Defines the requirement of the semantic search. 0 = no culling of dissimilar pages.')  # TODO cdg
+                search_threads = gr.Number(value=params['threads'], label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
+                update_search = gr.Button('Load data')
+
+                with gr.Accordion("Click for more information...", open=False):
+                    gr.Markdown(textwrap.dedent("""
+
+                    # installation/setup
+                    Please follow the instructions found here to set up a custom search engine with Google:
+                    https://www.thepythoncode.com/article/use-google-custom-search-engine-api-in-python
+
+                    Create a file called "custom_search_engine_keys.json"
+
+                    Paste this text into it and replace the placeholders with your values from the previous step:
+
+                    ```
+                    {
+                        "key": "Custom search engine key",
+                        "cx": "Custom search engine cx number"
+                    }
+                    ```
+
+                    # usage
+                    Enter a search query above and press the Load data button. The returned results will be added to the local chromaDB and read into context at runtime.
+
+                    """))
+
             with gr.Tab("Generation settings"):
                 chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
                 gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)')
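
As a concrete illustration of the setup step described in that accordion, the sketch below writes custom_search_engine_keys.json in the expected shape and then checks it the same way feed_search_into_collector does. The key and cx values are placeholders, not working credentials:

import json

# Placeholder credentials: substitute your own Google API key and CSE ID
config = {
    "key": "YOUR_GOOGLE_API_KEY",
    "cx": "YOUR_CUSTOM_SEARCH_ENGINE_ID",
}

with open('custom_search_engine_keys.json', 'w') as key_file:
    json.dump(config, key_file, indent=4)

# Mirror the check used in feed_search_into_collector: missing fields fall back
# to the "default_*" sentinels, which cause the search to be skipped
with open('custom_search_engine_keys.json') as key_file:
    key = json.load(key_file)

assert key.get("key", "default_key_value") != "default_key_value"
assert key.get("cx", "default_custom_engine_value") != "default_custom_engine_value"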
@@ -256,4 +373,5 @@ def ui():
     update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
     update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
     update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
+    update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated, show_progress=False)
     update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
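
Since feed_search_into_collector is a generator that yields progress strings (the same contract feed_url_into_collector follows), it can also be driven directly, outside the Gradio click handler. A hypothetical invocation, with argument values chosen only for illustration:

# Hypothetical direct call; chunk_len, chunk_sep and threads values are arbitrary examples
for status in feed_search_into_collector(
        query="text-generation-webui superbooga",
        chunk_len=700,
        chunk_sep='',
        strong_cleanup=True,
        semantic_cleanup=False,      # not implemented yet, per the UI label
        semantic_requirement=0.5,    # not implemented yet, per the UI label
        threads=4):
    print(status)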