mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-30 06:00:15 +01:00
Make chunk length/count customizable
This commit is contained in:
parent
8c06eeaf84
commit
04eca9b65b
@ -65,13 +65,15 @@ class SentenceTransformerEmbedder(Embedder):
|
||||
|
||||
embedder = SentenceTransformerEmbedder()
|
||||
collector = ChromaCollector(embedder)
|
||||
chunk_count = 5
|
||||
|
||||
|
||||
def feed_data_into_collector(corpus):
|
||||
global collector
|
||||
def feed_data_into_collector(corpus, chunk_len, _chunk_count):
|
||||
global collector, chunk_count
|
||||
chunk_count = int(_chunk_count)
|
||||
chunk_len = int(chunk_len)
|
||||
|
||||
cumulative = ''
|
||||
chunk_len = 700
|
||||
cumulative += "Breaking the input dataset...\n\n"
|
||||
yield cumulative
|
||||
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
|
||||
@ -83,14 +85,14 @@ def feed_data_into_collector(corpus):
|
||||
yield cumulative
|
||||
|
||||
|
||||
def feed_file_into_collector(file):
|
||||
def feed_file_into_collector(file, chunk_len, chunk_count):
|
||||
yield 'Reading the input dataset...\n\n'
|
||||
text = file.decode('utf-8')
|
||||
for i in feed_data_into_collector(text):
|
||||
for i in feed_data_into_collector(text, chunk_len, chunk_count):
|
||||
yield i
|
||||
|
||||
|
||||
def feed_url_into_collector(url):
|
||||
def feed_url_into_collector(url, chunk_len, chunk_count):
|
||||
yield 'Loading the URL...'
|
||||
html = urlopen(url).read()
|
||||
soup = BeautifulSoup(html, features="html.parser")
|
||||
@ -101,7 +103,7 @@ def feed_url_into_collector(url):
|
||||
lines = (line.strip() for line in text.splitlines())
|
||||
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
||||
text = '\n\n'.join(chunk for chunk in chunks if chunk)
|
||||
for i in feed_data_into_collector(text):
|
||||
for i in feed_data_into_collector(text, chunk_len, chunk_count):
|
||||
yield i
|
||||
|
||||
|
||||
@ -115,8 +117,8 @@ def input_modifier(string):
|
||||
else:
|
||||
user_input = ''
|
||||
|
||||
# Get the 5 most similar chunks
|
||||
results = collector.get(user_input, n_results=5)
|
||||
# Get the most similar chunks
|
||||
results = collector.get(user_input, n_results=chunk_count)
|
||||
|
||||
# Make the replacements
|
||||
string = string.replace('<|begin-user-input|>', '')
|
||||
@ -178,9 +180,13 @@ def ui():
|
||||
file_input = gr.File(label='Input file', type='binary')
|
||||
update_file = gr.Button('Apply')
|
||||
|
||||
with gr.Row():
|
||||
chunk_len = gr.Number(value=700, label='Chunk length', info='In characters, not tokens')
|
||||
chunk_count = gr.Number(value=5, label='Chunk count', info='The number of closest-matching chunks to include in the prompt')
|
||||
|
||||
with gr.Column():
|
||||
last_updated = gr.Markdown()
|
||||
|
||||
update_data.click(feed_data_into_collector, data_input, last_updated, show_progress=False)
|
||||
update_url.click(feed_url_into_collector, url_input, last_updated, show_progress=False)
|
||||
update_file.click(feed_file_into_collector, file_input, last_updated, show_progress=False)
|
||||
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_count], last_updated, show_progress=False)
|
||||
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_count], last_updated, show_progress=False)
|
||||
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_count], last_updated, show_progress=False)
|
||||
|
Loading…
Reference in New Issue
Block a user