Mirror of https://github.com/oobabooga/text-generation-webui.git
Synced 2024-11-26 09:40:20 +01:00

Commit 0480d70bfd (parent ab4ca9a3dd): updated superbooga with search integration and semantic filtering
@@ -12,6 +12,8 @@ from .download_urls import download_urls
 
 import requests
 import json
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
 
 params = {
     'chunk_count': 5,
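Note: the two new imports rely on sentence-transformers (for the embedding model) and scikit-learn (for cosine similarity). If those packages are not already present in the superbooga environment, something like pip install sentence-transformers scikit-learn would be needed.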
@@ -20,6 +22,8 @@ params = {
     'chunk_length': 700,
     'chunk_separator': '',
     'strong_cleanup': False,
+    'semantic_cleanup': True,
+    'semantic_weight': 0.5,
     'threads': 4,
 }
 
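For orientation, a minimal sketch (not taken from the commit; the pages and scores are made up) of what the semantic_weight default of 0.5 means as a cosine-similarity cutoff on search results:

scores = {'page A': 0.71, 'page B': 0.48, 'page C': 0.55}
threshold = params['semantic_weight']  # 0.5 by default
kept = [page for page, score in scores.items() if score >= threshold]
# kept == ['page A', 'page C']; page B falls below the cutoff and is skipped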
@@ -60,7 +64,6 @@ def feed_file_into_collector(file, chunk_len, chunk_sep):
 
 
 def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads):
-    print("feed_url_into_collector")
     all_text = ''
     cumulative = ''
 
@@ -87,13 +90,18 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
     for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
         yield i
 
 
+def calculate_semantic_similarity(query_embedding, target_embedding):
+    # Calculate cosine similarity between the query embedding and the target embedding
+    similarity = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1))
+    return similarity[0][0]
+
+
 def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
     # Load parameters from the config file
     with open('custom_search_engine_keys.json') as key_file:
         key = json.load(key_file)
 
-    print("=== Searching google ===")
-    print("-- " + str(query))
+    model = SentenceTransformer('all-MiniLM-L6-v2')
+    query_embedding = model.encode([query])[0]
 
     # Set up API endpoint and parameters
     url = "https://www.googleapis.com/customsearch/v1"
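As a quick sanity check of the new helper, here is a standalone sketch using the same all-MiniLM-L6-v2 model as the commit (the example strings are illustrative, not from the source):

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

model = SentenceTransformer('all-MiniLM-L6-v2')

query_embedding = model.encode(["how to fine-tune a llama model"])[0]
target_embedding = model.encode(["LLaMA fine-tuning guide - step by step"])[0]

# Same math as calculate_semantic_similarity(): cosine similarity of the two vectors
score = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1))[0][0]
print(score)  # a float in roughly [-1, 1]; higher means the texts are more related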
@@ -113,12 +121,6 @@ def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, sema
         print("You need to provide an CSE ID, by modifying the script.py in oobabooga_windows \ text-generation-webui.\nSkipping search")
         return query
 
-    # constructing the URL
-    # doc: https://developers.google.com/custom-search/v1/using_rest
-    # calculating start, (page=2) => (start=11), (page=3) => (start=21)
-    page = 1
-    start = (page - 1) * 10 + 1
-
     # Send API request
     response = requests.get(url, params=params)
 
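The request itself is built from values loaded out of custom_search_engine_keys.json, which this diff never shows. A hedged sketch of what that wiring typically looks like for Google's Custom Search JSON API; the "key" and "cx" field names in the JSON file are assumptions, only the endpoint and the query parameter names come from the API documentation:

# custom_search_engine_keys.json (assumed layout):
# {"key": "<your API key>", "cx": "<your custom search engine ID>"}

params = {
    "key": key["key"],   # API key loaded from the JSON file above
    "cx": key["cx"],     # custom search engine ID
    "q": query,          # the user's search term
}
response = requests.get("https://www.googleapis.com/customsearch/v1", params=params)
data = response.json()   # the hunk below reads the result list from data.get("items")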
@@ -127,49 +129,36 @@ def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, sema
 
     # get the result items
     search_items = data.get("items")
 
     # iterate over 10 results found
-    search_urls = ""
+    urls = ""
     for i, search_item in enumerate(search_items, start=1):
-        try:
-            long_description = search_item["pagemap"]["metatags"][0]["og:description"]
-        except KeyError:
-            long_description = "N/A"
         # get the page title
         title = search_item.get("title")
         # page snippet
         snippet = search_item.get("snippet")
-        # alternatively, you can get the HTML snippet (bolded keywords)
-        html_snippet = search_item.get("htmlSnippet")
-        # extract the page url
+
+        if semantic_cleanup:
+            # get titles and descriptions and use that to semantically weight the search result
+            target_sentence = str(title) + " " + str(snippet)
+            target_embedding = model.encode([target_sentence])[0]
+
+            similarity_score = calculate_semantic_similarity(query_embedding, target_embedding)
+
+            if similarity_score < semantic_requirement:
+                continue
+
+        # extract the page url and add it to the urls to download
         link = search_item.get("link")
-        search_urls += link + "\n"
+        urls += link + "\n"
 
-    # TODO don't clone feed_url_into_collector
-    all_text = ''
-    cumulative = ''
-
-    urls = search_urls.strip().split('\n')
-    cumulative += f'Loading {len(urls)} URLs with {threads} threads...\n\n'
-    yield cumulative
-    for update, contents in download_urls(urls, threads=threads):
-        yield cumulative + update
-
-    cumulative += 'Processing the HTML sources...'
-    yield cumulative
-    for content in contents:
-        soup = BeautifulSoup(content, features="html.parser")
-        for script in soup(["script", "style"]):
-            script.extract()
-
-        strings = soup.stripped_strings
-        if strong_cleanup:
-            strings = [s for s in strings if re.search("[A-Za-z] ", s)]
-
-        text = '\n'.join([s.strip() for s in strings])
-        all_text += text
-
-    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
-        yield i
+    # Call the original feed_url_into_collector function instead of duplicating the code
+    result_generator = feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
+
+    # Consume the yielded values
+    for result in result_generator:
+        yield result
 
 
 def apply_settings(chunk_count, chunk_count_initial, time_weight):
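Because the function now delegates to feed_url_into_collector, it remains a generator of progress strings. A minimal sketch of driving it outside Gradio (the argument values are illustrative, mirroring the defaults in params):

for status in feed_search_into_collector(
        query="quantization benchmarks",
        chunk_len=700,
        chunk_sep='',
        strong_cleanup=False,
        semantic_cleanup=True,
        semantic_requirement=0.5,
        threads=4):
    print(status)  # progress messages such as 'Loading N URLs with 4 threads...'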
@@ -330,8 +319,10 @@ def ui():
         with gr.Tab("Search input"):
             search_term = gr.Textbox(lines=1, label='Search Input', info='Enter a google search, returned results will be fed into the DB')
             search_strong_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Strong cleanup', info='Only keeps html elements that look like long-form text.')
-            semantic_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Require semantic similarity (not implemented)', info='Only download pages with similar titles/snippets to the search') # TODO cdg
-            semantic_requirement = gr.Slider(0, 1, value=params['time_weight'], label='Semantic similarity requirement (not implemented)', info='Defines the requirement of the semantic search. 0 = no culling of dissimilar pages.') # TODO cdg
+            semantic_cleanup = gr.Checkbox(value=params['semantic_cleanup'], label='Require semantic similarity', info='Only download pages with similar titles/snippets to the search based on a semantic search')
+            semantic_requirement = gr.Slider(0, 1, value=params['semantic_weight'], label='Semantic similarity requirement', info='Defines the requirement of the semantic search. 0 = no culling of dissimilar pages.')
+
             search_threads = gr.Number(value=params['threads'], label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
             update_search = gr.Button('Load data')
 
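The click handler that feeds these components into feed_search_into_collector sits outside the visible hunks. A plausible wiring, modeled on how the other "Load data" buttons in this extension are hooked up; the output component name and the exact call are assumptions, not part of this commit:

# Hypothetical wiring; component names not defined in the hunk above are assumptions.
update_search.click(
    feed_search_into_collector,
    [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads],
    last_updated,
    show_progress=False)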