mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Generalize superbooga to chat mode
This commit is contained in:
parent
ec1cda0e1f
commit
6b67cb6611
@ -8,9 +8,10 @@ import posthog
|
||||
import torch
|
||||
from bs4 import BeautifulSoup
|
||||
from chromadb.config import Settings
|
||||
from modules import shared
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from modules import chat, shared
|
||||
|
||||
print('Intercepting all calls to posthog :)')
|
||||
posthog.capture = lambda *args, **kwargs: None
|
||||
|
||||
@ -53,6 +54,10 @@ class ChromaCollector(Collecter):
|
||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
|
||||
return result
|
||||
|
||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
|
||||
return list(map(lambda x : int(x[2:]), result))
|
||||
|
||||
def clear(self):
|
||||
self.collection.delete(ids=self.ids)
|
||||
|
||||
@ -68,18 +73,24 @@ collector = ChromaCollector(embedder)
|
||||
chunk_count = 5
|
||||
|
||||
|
||||
def feed_data_into_collector(corpus, chunk_len):
|
||||
def add_chunks_to_collector(chunks):
|
||||
global collector
|
||||
chunk_len = int(chunk_len)
|
||||
collector.clear()
|
||||
collector.add(chunks)
|
||||
|
||||
|
||||
def feed_data_into_collector(corpus, chunk_len):
|
||||
# Defining variables
|
||||
chunk_len = int(chunk_len)
|
||||
cumulative = ''
|
||||
|
||||
# Breaking the data into chunks and adding those to the db
|
||||
cumulative += "Breaking the input dataset...\n\n"
|
||||
yield cumulative
|
||||
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
|
||||
cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
|
||||
yield cumulative
|
||||
collector.clear()
|
||||
collector.add(data_chunks)
|
||||
add_chunks_to_collector(data_chunks)
|
||||
cumulative += "Done."
|
||||
yield cumulative
|
||||
|
||||
@ -123,6 +134,8 @@ def apply_settings(_chunk_count):
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
if shared.is_chat():
|
||||
return string
|
||||
|
||||
# Find the user input
|
||||
pattern = re.compile(r"<\|begin-user-input\|>(.*?)<\|end-user-input\|>", re.DOTALL)
|
||||
@ -143,6 +156,23 @@ def input_modifier(string):
|
||||
return string
|
||||
|
||||
|
||||
def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
if len(shared.history['internal']) > 2 and user_input != '':
|
||||
chunks = []
|
||||
for i in range(len(shared.history['internal'])-1):
|
||||
chunks.append('\n'.join(shared.history['internal'][i]))
|
||||
|
||||
add_chunks_to_collector(chunks)
|
||||
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
||||
best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1)
|
||||
|
||||
# Sort the history by relevance instead of by chronological order,
|
||||
# except for the latest message
|
||||
state['history'] = [shared.history['internal'][id_] for id_ in best_ids[::-1]] + [shared.history['internal'][-1]]
|
||||
|
||||
return chat.generate_chat_prompt(user_input, state, **kwargs)
|
||||
|
||||
|
||||
def ui():
|
||||
with gr.Accordion("Click for more information...", open=False):
|
||||
gr.Markdown(textwrap.dedent("""
|
||||
@ -156,7 +186,9 @@ def ui():
|
||||
|
||||
It is a modified version of the superbig extension by kaiokendev: https://github.com/kaiokendev/superbig
|
||||
|
||||
## How to use it
|
||||
## Notebook/default modes
|
||||
|
||||
### How to use it
|
||||
|
||||
1) Paste your input text (of whatever length) into the text box below.
|
||||
2) Click on "Load data" to feed this text into the Chroma database.
|
||||
@ -166,7 +198,7 @@ def ui():
|
||||
|
||||
The special tokens mentioned above (`<|begin-user-input|>`, `<|end-user-input|>`, and `<|injection-point|>`) are removed when the injection happens.
|
||||
|
||||
## Example
|
||||
### Example
|
||||
|
||||
For your convenience, you can use the following prompt as a starting point (for Alpaca models):
|
||||
|
||||
@ -186,14 +218,25 @@ def ui():
|
||||
### Response:
|
||||
```
|
||||
|
||||
## Chat mode
|
||||
|
||||
In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
|
||||
|
||||
That is, the prompt will include (starting from the end):
|
||||
|
||||
* Your input
|
||||
* The latest input/reply pair
|
||||
* The #1 most relevant input/reply pair prior to the latest
|
||||
* The #2 most relevant input/reply pair prior to the latest
|
||||
* Etc
|
||||
|
||||
This way, the bot can have a long term history.
|
||||
|
||||
*This extension is currently experimental and under development.*
|
||||
|
||||
"""))
|
||||
|
||||
if shared.is_chat():
|
||||
# Chat mode has to be handled differently, probably using a custom_generate_chat_prompt
|
||||
pass
|
||||
else:
|
||||
if not shared.is_chat():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Tab("Text input"):
|
||||
|
@ -27,6 +27,11 @@ def replace_all(text, dic):
|
||||
|
||||
|
||||
def generate_chat_prompt(user_input, state, **kwargs):
|
||||
# Check if an extension is sending its modified history.
|
||||
# If not, use the regular history
|
||||
history = state['history'] if 'history' in state else shared.history['internal']
|
||||
|
||||
# Define some variables
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
@ -61,14 +66,14 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
|
||||
|
||||
# Building the prompt
|
||||
i = len(shared.history['internal']) - 1
|
||||
i = len(history) - 1
|
||||
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||
if _continue and i == len(shared.history['internal']) - 1:
|
||||
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
|
||||
if _continue and i == len(history) - 1:
|
||||
rows.insert(1, bot_turn_stripped + history[i][1].strip())
|
||||
else:
|
||||
rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip()))
|
||||
rows.insert(1, bot_turn.replace('<|bot-message|>', history[i][1].strip()))
|
||||
|
||||
string = shared.history['internal'][i][0]
|
||||
string = history[i][0]
|
||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
|
||||
|
||||
@ -80,7 +85,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
elif not _continue:
|
||||
# Adding the user message
|
||||
if len(user_input) > 0:
|
||||
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))}))
|
||||
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
|
||||
|
Loading…
Reference in New Issue
Block a user