mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Generalize superbooga to chat mode
This commit is contained in:
parent
ec1cda0e1f
commit
6b67cb6611
@ -8,9 +8,10 @@ import posthog
|
|||||||
import torch
|
import torch
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from modules import shared
|
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from modules import chat, shared
|
||||||
|
|
||||||
print('Intercepting all calls to posthog :)')
|
print('Intercepting all calls to posthog :)')
|
||||||
posthog.capture = lambda *args, **kwargs: None
|
posthog.capture = lambda *args, **kwargs: None
|
||||||
|
|
||||||
@ -53,6 +54,10 @@ class ChromaCollector(Collecter):
|
|||||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
|
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||||
|
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
|
||||||
|
return list(map(lambda x : int(x[2:]), result))
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.collection.delete(ids=self.ids)
|
self.collection.delete(ids=self.ids)
|
||||||
|
|
||||||
@ -68,18 +73,24 @@ collector = ChromaCollector(embedder)
|
|||||||
chunk_count = 5
|
chunk_count = 5
|
||||||
|
|
||||||
|
|
||||||
def feed_data_into_collector(corpus, chunk_len):
|
def add_chunks_to_collector(chunks):
|
||||||
global collector
|
global collector
|
||||||
chunk_len = int(chunk_len)
|
collector.clear()
|
||||||
|
collector.add(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def feed_data_into_collector(corpus, chunk_len):
|
||||||
|
# Defining variables
|
||||||
|
chunk_len = int(chunk_len)
|
||||||
cumulative = ''
|
cumulative = ''
|
||||||
|
|
||||||
|
# Breaking the data into chunks and adding those to the db
|
||||||
cumulative += "Breaking the input dataset...\n\n"
|
cumulative += "Breaking the input dataset...\n\n"
|
||||||
yield cumulative
|
yield cumulative
|
||||||
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
|
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
|
||||||
cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
|
cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
|
||||||
yield cumulative
|
yield cumulative
|
||||||
collector.clear()
|
add_chunks_to_collector(data_chunks)
|
||||||
collector.add(data_chunks)
|
|
||||||
cumulative += "Done."
|
cumulative += "Done."
|
||||||
yield cumulative
|
yield cumulative
|
||||||
|
|
||||||
@ -123,6 +134,8 @@ def apply_settings(_chunk_count):
|
|||||||
|
|
||||||
|
|
||||||
def input_modifier(string):
|
def input_modifier(string):
|
||||||
|
if shared.is_chat():
|
||||||
|
return string
|
||||||
|
|
||||||
# Find the user input
|
# Find the user input
|
||||||
pattern = re.compile(r"<\|begin-user-input\|>(.*?)<\|end-user-input\|>", re.DOTALL)
|
pattern = re.compile(r"<\|begin-user-input\|>(.*?)<\|end-user-input\|>", re.DOTALL)
|
||||||
@ -143,6 +156,23 @@ def input_modifier(string):
|
|||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||||
|
if len(shared.history['internal']) > 2 and user_input != '':
|
||||||
|
chunks = []
|
||||||
|
for i in range(len(shared.history['internal'])-1):
|
||||||
|
chunks.append('\n'.join(shared.history['internal'][i]))
|
||||||
|
|
||||||
|
add_chunks_to_collector(chunks)
|
||||||
|
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
||||||
|
best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1)
|
||||||
|
|
||||||
|
# Sort the history by relevance instead of by chronological order,
|
||||||
|
# except for the latest message
|
||||||
|
state['history'] = [shared.history['internal'][id_] for id_ in best_ids[::-1]] + [shared.history['internal'][-1]]
|
||||||
|
|
||||||
|
return chat.generate_chat_prompt(user_input, state, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
with gr.Accordion("Click for more information...", open=False):
|
with gr.Accordion("Click for more information...", open=False):
|
||||||
gr.Markdown(textwrap.dedent("""
|
gr.Markdown(textwrap.dedent("""
|
||||||
@ -156,7 +186,9 @@ def ui():
|
|||||||
|
|
||||||
It is a modified version of the superbig extension by kaiokendev: https://github.com/kaiokendev/superbig
|
It is a modified version of the superbig extension by kaiokendev: https://github.com/kaiokendev/superbig
|
||||||
|
|
||||||
## How to use it
|
## Notebook/default modes
|
||||||
|
|
||||||
|
### How to use it
|
||||||
|
|
||||||
1) Paste your input text (of whatever length) into the text box below.
|
1) Paste your input text (of whatever length) into the text box below.
|
||||||
2) Click on "Load data" to feed this text into the Chroma database.
|
2) Click on "Load data" to feed this text into the Chroma database.
|
||||||
@ -166,7 +198,7 @@ def ui():
|
|||||||
|
|
||||||
The special tokens mentioned above (`<|begin-user-input|>`, `<|end-user-input|>`, and `<|injection-point|>`) are removed when the injection happens.
|
The special tokens mentioned above (`<|begin-user-input|>`, `<|end-user-input|>`, and `<|injection-point|>`) are removed when the injection happens.
|
||||||
|
|
||||||
## Example
|
### Example
|
||||||
|
|
||||||
For your convenience, you can use the following prompt as a starting point (for Alpaca models):
|
For your convenience, you can use the following prompt as a starting point (for Alpaca models):
|
||||||
|
|
||||||
@ -186,14 +218,25 @@ def ui():
|
|||||||
### Response:
|
### Response:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Chat mode
|
||||||
|
|
||||||
|
In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
|
||||||
|
|
||||||
|
That is, the prompt will include (starting from the end):
|
||||||
|
|
||||||
|
* Your input
|
||||||
|
* The latest input/reply pair
|
||||||
|
* The #1 most relevant input/reply pair prior to the latest
|
||||||
|
* The #2 most relevant input/reply pair prior to the latest
|
||||||
|
* Etc
|
||||||
|
|
||||||
|
This way, the bot can have a long term history.
|
||||||
|
|
||||||
*This extension is currently experimental and under development.*
|
*This extension is currently experimental and under development.*
|
||||||
|
|
||||||
"""))
|
"""))
|
||||||
|
|
||||||
if shared.is_chat():
|
if not shared.is_chat():
|
||||||
# Chat mode has to be handled differently, probably using a custom_generate_chat_prompt
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Tab("Text input"):
|
with gr.Tab("Text input"):
|
||||||
|
@ -27,6 +27,11 @@ def replace_all(text, dic):
|
|||||||
|
|
||||||
|
|
||||||
def generate_chat_prompt(user_input, state, **kwargs):
|
def generate_chat_prompt(user_input, state, **kwargs):
|
||||||
|
# Check if an extension is sending its modified history.
|
||||||
|
# If not, use the regular history
|
||||||
|
history = state['history'] if 'history' in state else shared.history['internal']
|
||||||
|
|
||||||
|
# Define some variables
|
||||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||||
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||||
@ -61,14 +66,14 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
|
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
|
||||||
|
|
||||||
# Building the prompt
|
# Building the prompt
|
||||||
i = len(shared.history['internal']) - 1
|
i = len(history) - 1
|
||||||
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||||
if _continue and i == len(shared.history['internal']) - 1:
|
if _continue and i == len(history) - 1:
|
||||||
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
|
rows.insert(1, bot_turn_stripped + history[i][1].strip())
|
||||||
else:
|
else:
|
||||||
rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip()))
|
rows.insert(1, bot_turn.replace('<|bot-message|>', history[i][1].strip()))
|
||||||
|
|
||||||
string = shared.history['internal'][i][0]
|
string = history[i][0]
|
||||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||||
rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
|
rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
|
||||||
|
|
||||||
@ -80,7 +85,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
elif not _continue:
|
elif not _continue:
|
||||||
# Adding the user message
|
# Adding the user message
|
||||||
if len(user_input) > 0:
|
if len(user_input) > 0:
|
||||||
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))}))
|
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
|
||||||
|
|
||||||
# Adding the Character prefix
|
# Adding the Character prefix
|
||||||
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
|
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
|
||||||
|
Loading…
Reference in New Issue
Block a user