Minor changes

Commit: 6e648ca494
Parent: 7c37b82362
Author: oobabooga
Date: 2023-07-13 14:15:17 -07:00

@@ -1,8 +1,12 @@
+import json
 import re
 import textwrap
 
 import gradio as gr
+import requests
 from bs4 import BeautifulSoup
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
 
 from modules import chat, shared
 from modules.logging_colors import logger
@@ -10,11 +14,6 @@ from modules.logging_colors import logger
 from .chromadb import add_chunks_to_collector, make_collector
 from .download_urls import download_urls
-import requests
-import json
-
-from sentence_transformers import SentenceTransformer
-from sklearn.metrics.pairwise import cosine_similarity
 
 params = {
     'chunk_count': 5,
     'chunk_count_initial': 10,
@@ -90,11 +89,13 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
     for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
         yield i
 
+
 def calculate_semantic_similarity(query_embedding, target_embedding):
     # Calculate cosine similarity between the query embedding and the target embedding
     similarity = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1))
     return similarity[0][0]
 
+
 def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
     # Load parameters from the config file
     with open('custom_search_engine_keys.json') as key_file:
@@ -148,7 +149,6 @@ def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads)
         if similarity_score < semantic_requirement:
             continue
-
 
         # extract the page url and add it to the urls to download
         link = search_item.get("link")
         urls += link + "\n"
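
Aside (not part of the commit): the two hunks above pair the calculate_semantic_similarity() helper with the semantic_requirement threshold to discard search results that are semantically unrelated to the query before their URLs are downloaded. Below is a minimal, self-contained sketch of that pattern; the model name, threshold value, and sample snippets are assumptions chosen for illustration, not values taken from this repository.

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


def calculate_semantic_similarity(query_embedding, target_embedding):
    # Cosine similarity between two 1-D embeddings, returned as a plain scalar
    similarity = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1))
    return similarity[0][0]


model = SentenceTransformer('all-MiniLM-L6-v2')  # assumed model choice
semantic_requirement = 0.3                       # assumed threshold

query = "how do vector databases store embeddings?"
search_snippets = [
    "Vector databases index high-dimensional embeddings for similarity search.",
    "Top ten pasta recipes for busy weeknights.",
]

query_embedding = model.encode(query)
for snippet in search_snippets:
    score = calculate_semantic_similarity(query_embedding, model.encode(snippet))
    if score < semantic_requirement:
        continue  # analogous to the 'continue' in the hunk above: skip unrelated results
    print(f"keep (similarity {score:.2f}): {snippet}")
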
@@ -173,6 +173,8 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight):
 def custom_generate_chat_prompt(user_input, state, **kwargs):
     global chat_collector
 
+    history = state['history']
+
     if state['mode'] == 'instruct':
         results = collector.get_sorted(user_input, n_results=params['chunk_count'])
         additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
@@ -182,29 +184,29 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
         def make_single_exchange(id_):
             output = ''
-            output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n"
-            output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
+            output += f"{state['name1']}: {history['internal'][id_][0]}\n"
+            output += f"{state['name2']}: {history['internal'][id_][1]}\n"
             return output
 
-        if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
+        if len(history['internal']) > params['chunk_count'] and user_input != '':
             chunks = []
-            hist_size = len(shared.history['internal'])
-            for i in range(hist_size-1):
+            hist_size = len(history['internal'])
+            for i in range(hist_size - 1):
                 chunks.append(make_single_exchange(i))
 
             add_chunks_to_collector(chunks, chat_collector)
-            query = '\n'.join(shared.history['internal'][-1] + [user_input])
+            query = '\n'.join(history['internal'][-1] + [user_input])
             try:
                 best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
                 additional_context = '\n'
                 for id_ in best_ids:
-                    if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
+                    if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
                         additional_context += make_single_exchange(id_)
 
                 logger.warning(f'Adding the following new context:\n{additional_context}')
                 state['context'] = state['context'].strip() + '\n' + additional_context
                 kwargs['history'] = {
-                    'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
+                    'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids],
                     'visible': ''
                 }
             except RuntimeError:
@@ -370,5 +372,5 @@ def ui():
     update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
     update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
     update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
-    update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated,show_progress=False)
+    update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated, show_progress=False)
     update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
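
Aside (not part of the commit): the bulk of this diff replaces reads of the global shared.history with the per-request history = state['history'] inside custom_generate_chat_prompt(). The sketch below illustrates, under invented sample data, the shape of the history that code consumes: history['internal'] is a list of [user_message, bot_reply] pairs, and make_single_exchange() flattens one pair into a text chunk for the chat collector. The names and messages are made up for illustration.

# Invented sample state; only the shape mirrors what the modified function receives.
state = {
    'name1': 'You',
    'name2': 'Assistant',
    'history': {
        'internal': [
            ['<|BEGIN-VISIBLE-CHAT|>', 'Hello! How can I help?'],
            ['What did we decide about the project name?', 'We agreed to call it "orchid".'],
            ['And the deadline?', 'End of the month.'],
        ],
        'visible': [],
    },
}

history = state['history']  # what the commit reads instead of the global shared.history


def make_single_exchange(id_):
    # Same formatting as the helper in the diff: one "name: message" line per speaker
    output = ''
    output += f"{state['name1']}: {history['internal'][id_][0]}\n"
    output += f"{state['name2']}: {history['internal'][id_][1]}\n"
    return output


# Chunks like these are what the diff passes to add_chunks_to_collector(chunks, chat_collector)
chunks = [make_single_exchange(i) for i in range(len(history['internal']) - 1)]
for chunk in chunks:
    print(chunk)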