mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 19:09:32 +01:00

updated superbooga with search integration and semantic filtering

parent ab4ca9a3dd
commit 0480d70bfd
@@ -12,6 +12,8 @@ from .download_urls import download_urls

import requests
import json
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

params = {
    'chunk_count': 5,
@@ -20,6 +22,8 @@ params = {
    'chunk_length': 700,
    'chunk_separator': '',
    'strong_cleanup': False,
    'semantic_cleanup': True,
    'semantic_weight': 0.5,
    'threads': 4,
}

@@ -60,7 +64,6 @@ def feed_file_into_collector(file, chunk_len, chunk_sep):


def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads):
    print("feed_url_into_collector")
    all_text = ''
    cumulative = ''

@@ -87,13 +90,18 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
        yield i

def calculate_semantic_similarity(query_embedding, target_embedding):
    # Calculate cosine similarity between the query embedding and the target embedding
    similarity = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1))
    return similarity[0][0]

def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
    # Load parameters from the config file
    with open('custom_search_engine_keys.json') as key_file:
        key = json.load(key_file)

    print("=== Searching google ===")
    print("-- " + str(query))
    model = SentenceTransformer('all-MiniLM-L6-v2')
    query_embedding = model.encode([query])[0]

    # Set up API endpoint and parameters
    url = "https://www.googleapis.com/customsearch/v1"
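One thing this excerpt never shows is what custom_search_engine_keys.json contains. Below is a minimal sketch of the file and of the request parameters it would feed, assuming the standard Google Custom Search REST parameters ('key' for the API key, 'cx' for the engine ID); these field names are an assumption, not something visible in this commit:

# Hypothetical custom_search_engine_keys.json (field names are assumptions):
# {
#     "key": "<your Google API key>",
#     "cx": "<your Programmable Search Engine ID>"
# }
import json
import requests

with open('custom_search_engine_keys.json') as key_file:
    key = json.load(key_file)

params = {
    'key': key.get('key'),  # assumed field name
    'cx': key.get('cx'),    # assumed field name
    'q': 'example query',
    'num': 10,
}
response = requests.get("https://www.googleapis.com/customsearch/v1", params=params)
data = response.json()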
@@ -113,12 +121,6 @@ def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
        print("You need to provide a CSE ID by modifying the script.py in oobabooga_windows \ text-generation-webui.\nSkipping search")
        return query

    # constructing the URL
    # doc: https://developers.google.com/custom-search/v1/using_rest
    # calculating start, (page=2) => (start=11), (page=3) => (start=21)
    page = 1
    start = (page - 1) * 10 + 1

    # Send API request
    response = requests.get(url, params=params)

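Between this hunk and the next one the diff elides how data is derived from response (presumably response.json()). A slightly more defensive version of that step, shown only as a sketch, would check the HTTP status before reaching into the payload:

# Sketch only; the commit itself goes straight from the request to data.get("items").
response = requests.get(url, params=params)
if response.status_code != 200:
    print(f"Search request failed with HTTP {response.status_code}, skipping search")
    return query
data = response.json()
search_items = data.get("items") or []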
@@ -127,49 +129,36 @@ def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):

    # get the result items
    search_items = data.get("items")

    # iterate over 10 results found
    search_urls = ""
    urls = ""
    for i, search_item in enumerate(search_items, start=1):
        try:
            long_description = search_item["pagemap"]["metatags"][0]["og:description"]
        except KeyError:
            long_description = "N/A"
        # get the page title
        title = search_item.get("title")
        # page snippet
        snippet = search_item.get("snippet")
        # alternatively, you can get the HTML snippet (bolded keywords)
        html_snippet = search_item.get("htmlSnippet")
        # extract the page url
        if semantic_cleanup:
            # get titles and descriptions and use that to semantically weight the search result
            # get the page title
            title = search_item.get("title")
            # page snippet
            snippet = search_item.get("snippet")

            target_sentence = str(title) + " " + str(snippet)
            target_embedding = model.encode([target_sentence])[0]

            similarity_score = calculate_semantic_similarity(query_embedding, target_embedding)

            if similarity_score < semantic_requirement:
                continue


        # extract the page url and add it to the urls to download
        link = search_item.get("link")
        search_urls += link + "\n"
        urls += link + "\n"

    # TODO don't clone feed_url_into_collector
    all_text = ''
    cumulative = ''
    # Call the original feed_url_into_collector function instead of duplicating the code
    result_generator = feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)

    urls = search_urls.strip().split('\n')
    cumulative += f'Loading {len(urls)} URLs with {threads} threads...\n\n'
    yield cumulative
    for update, contents in download_urls(urls, threads=threads):
        yield cumulative + update

    cumulative += 'Processing the HTML sources...'
    yield cumulative
    for content in contents:
        soup = BeautifulSoup(content, features="html.parser")
        for script in soup(["script", "style"]):
            script.extract()

        strings = soup.stripped_strings
        if strong_cleanup:
            strings = [s for s in strings if re.search("[A-Za-z] ", s)]

        text = '\n'.join([s.strip() for s in strings])
        all_text += text

    for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
        yield i
    # Consume the yielded values
    for result in result_generator:
        yield result


def apply_settings(chunk_count, chunk_count_initial, time_weight):
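Taken together, the semantic cleanup added in this commit boils down to: embed the query once, embed each result's title plus snippet, and drop results whose cosine similarity falls below the configured threshold. Here is a self-contained sketch of that idea; the example query and results are made up, and the 0.5 threshold mirrors the default 'semantic_weight':

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Standalone illustration of the filtering idea, not the extension's exact code path.
model = SentenceTransformer('all-MiniLM-L6-v2')

query = "open source text generation web UI"  # made-up query
results = [                                    # made-up (title, snippet) pairs
    ("text-generation-webui", "A Gradio web UI for running large language models."),
    ("Cookie recipes", "Soft and chewy chocolate chip cookies."),
]

query_embedding = model.encode([query])[0]
threshold = 0.5  # corresponds to params['semantic_weight']

for title, snippet in results:
    target_embedding = model.encode([title + " " + snippet])[0]
    score = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1))[0][0]
    if score >= threshold:
        print(f"keep ({score:.2f}) {title}")
    else:
        print(f"drop ({score:.2f}) {title}")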
@@ -330,8 +319,10 @@ def ui():
            with gr.Tab("Search input"):
                search_term = gr.Textbox(lines=1, label='Search Input', info='Enter a google search, returned results will be fed into the DB')
                search_strong_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Strong cleanup', info='Only keeps html elements that look like long-form text.')
                semantic_cleanup = gr.Checkbox(value=params['strong_cleanup'], label='Require semantic similarity (not implemented)', info='Only download pages with similar titles/snippets to the search')  # TODO cdg
                semantic_requirement = gr.Slider(0, 1, value=params['time_weight'], label='Semantic similarity requirement (not implemented)', info='Defines the requirement of the semantic search. 0 = no culling of dissimilar pages.')  # TODO cdg

                semantic_cleanup = gr.Checkbox(value=params['semantic_cleanup'], label='Require semantic similarity', info='Only download pages with similar titles/snippets to the search based on a semantic search')
                semantic_requirement = gr.Slider(0, 1, value=params['semantic_weight'], label='Semantic similarity requirement', info='Defines the requirement of the semantic search. 0 = no culling of dissimilar pages.')

                search_threads = gr.Number(value=params['threads'], label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
                update_search = gr.Button('Load data')
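The click handler that routes these widgets into feed_search_into_collector sits outside the shown hunks. A plausible wiring, using the component names from this hunk plus hypothetical chunk and output components, might look roughly like this:

# Sketch only; the actual binding is not part of the shown diff.
# chunk_len, chunk_sep and last_updated are hypothetical component names here.
update_search.click(
    feed_search_into_collector,
    inputs=[search_term, chunk_len, chunk_sep, search_strong_cleanup,
            semantic_cleanup, semantic_requirement, search_threads],
    outputs=[last_updated],
    show_progress=False)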