From 6e648ca494d0bf86b433ca89d9defd0c7af39e77 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 13 Jul 2023 14:15:17 -0700 Subject: [PATCH] Minor changes --- extensions/superbooga/script.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py index fee1e56b..54309671 100644 --- a/extensions/superbooga/script.py +++ b/extensions/superbooga/script.py @@ -1,8 +1,12 @@ +import json import re import textwrap import gradio as gr +import requests from bs4 import BeautifulSoup +from sentence_transformers import SentenceTransformer +from sklearn.metrics.pairwise import cosine_similarity from modules import chat, shared from modules.logging_colors import logger @@ -10,11 +14,6 @@ from modules.logging_colors import logger from .chromadb import add_chunks_to_collector, make_collector from .download_urls import download_urls -import requests -import json -from sentence_transformers import SentenceTransformer -from sklearn.metrics.pairwise import cosine_similarity - params = { 'chunk_count': 5, 'chunk_count_initial': 10, @@ -90,11 +89,13 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads) for i in feed_data_into_collector(all_text, chunk_len, chunk_sep): yield i + def calculate_semantic_similarity(query_embedding, target_embedding): # Calculate cosine similarity between the query embedding and the target embedding similarity = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1)) return similarity[0][0] + def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads): # Load parameters from the config file with open('custom_search_engine_keys.json') as key_file: @@ -148,7 +149,6 @@ def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, sema if similarity_score < semantic_requirement: continue - # extract the page url and add it to the urls to download link = search_item.get("link") urls += link + "\n" @@ -173,6 +173,8 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight): def custom_generate_chat_prompt(user_input, state, **kwargs): global chat_collector + history = state['history'] + if state['mode'] == 'instruct': results = collector.get_sorted(user_input, n_results=params['chunk_count']) additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results) @@ -182,29 +184,29 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): def make_single_exchange(id_): output = '' - output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n" - output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n" + output += f"{state['name1']}: {history['internal'][id_][0]}\n" + output += f"{state['name2']}: {history['internal'][id_][1]}\n" return output - if len(shared.history['internal']) > params['chunk_count'] and user_input != '': + if len(history['internal']) > params['chunk_count'] and user_input != '': chunks = [] - hist_size = len(shared.history['internal']) - for i in range(hist_size-1): + hist_size = len(history['internal']) + for i in range(hist_size - 1): chunks.append(make_single_exchange(i)) add_chunks_to_collector(chunks, chat_collector) - query = '\n'.join(shared.history['internal'][-1] + [user_input]) + query = '\n'.join(history['internal'][-1] + [user_input]) try: best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight']) additional_context = '\n' for id_ in best_ids: - if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>': + if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>': additional_context += make_single_exchange(id_) logger.warning(f'Adding the following new context:\n{additional_context}') state['context'] = state['context'].strip() + '\n' + additional_context kwargs['history'] = { - 'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids], + 'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids], 'visible': '' } except RuntimeError: @@ -370,5 +372,5 @@ def ui(): update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False) update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False) update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False) - update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated,show_progress=False) + update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated, show_progress=False) update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)