mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Minor changes
This commit is contained in:
parent
7c37b82362
commit
6e648ca494
@ -1,8 +1,12 @@
|
||||
import json
|
||||
import re
|
||||
import textwrap
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
from modules import chat, shared
|
||||
from modules.logging_colors import logger
|
||||
@ -10,11 +14,6 @@ from modules.logging_colors import logger
|
||||
from .chromadb import add_chunks_to_collector, make_collector
|
||||
from .download_urls import download_urls
|
||||
|
||||
import requests
|
||||
import json
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
params = {
|
||||
'chunk_count': 5,
|
||||
'chunk_count_initial': 10,
|
||||
@ -90,11 +89,13 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
|
||||
for i in feed_data_into_collector(all_text, chunk_len, chunk_sep):
|
||||
yield i
|
||||
|
||||
|
||||
def calculate_semantic_similarity(query_embedding, target_embedding):
|
||||
# Calculate cosine similarity between the query embedding and the target embedding
|
||||
similarity = cosine_similarity(query_embedding.reshape(1, -1), target_embedding.reshape(1, -1))
|
||||
return similarity[0][0]
|
||||
|
||||
|
||||
def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, semantic_cleanup, semantic_requirement, threads):
|
||||
# Load parameters from the config file
|
||||
with open('custom_search_engine_keys.json') as key_file:
|
||||
@ -148,7 +149,6 @@ def feed_search_into_collector(query, chunk_len, chunk_sep, strong_cleanup, sema
|
||||
if similarity_score < semantic_requirement:
|
||||
continue
|
||||
|
||||
|
||||
# extract the page url and add it to the urls to download
|
||||
link = search_item.get("link")
|
||||
urls += link + "\n"
|
||||
@ -173,6 +173,8 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight):
|
||||
def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
global chat_collector
|
||||
|
||||
history = state['history']
|
||||
|
||||
if state['mode'] == 'instruct':
|
||||
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
|
||||
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
|
||||
@ -182,29 +184,29 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
|
||||
def make_single_exchange(id_):
|
||||
output = ''
|
||||
output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n"
|
||||
output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
|
||||
output += f"{state['name1']}: {history['internal'][id_][0]}\n"
|
||||
output += f"{state['name2']}: {history['internal'][id_][1]}\n"
|
||||
return output
|
||||
|
||||
if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
|
||||
if len(history['internal']) > params['chunk_count'] and user_input != '':
|
||||
chunks = []
|
||||
hist_size = len(shared.history['internal'])
|
||||
for i in range(hist_size-1):
|
||||
hist_size = len(history['internal'])
|
||||
for i in range(hist_size - 1):
|
||||
chunks.append(make_single_exchange(i))
|
||||
|
||||
add_chunks_to_collector(chunks, chat_collector)
|
||||
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
||||
query = '\n'.join(history['internal'][-1] + [user_input])
|
||||
try:
|
||||
best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
|
||||
additional_context = '\n'
|
||||
for id_ in best_ids:
|
||||
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
additional_context += make_single_exchange(id_)
|
||||
|
||||
logger.warning(f'Adding the following new context:\n{additional_context}')
|
||||
state['context'] = state['context'].strip() + '\n' + additional_context
|
||||
kwargs['history'] = {
|
||||
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
|
||||
'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids],
|
||||
'visible': ''
|
||||
}
|
||||
except RuntimeError:
|
||||
@ -370,5 +372,5 @@ def ui():
|
||||
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
|
||||
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
|
||||
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
|
||||
update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated,show_progress=False)
|
||||
update_search.click(feed_search_into_collector, [search_term, chunk_len, chunk_sep, search_strong_cleanup, semantic_cleanup, semantic_requirement, search_threads], last_updated, show_progress=False)
|
||||
update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
|
||||
|
Loading…
Reference in New Issue
Block a user