Add support for extensions

This is experimental.
This commit is contained in:
oobabooga 2023-01-27 00:40:39 -03:00
parent 414fa9d161
commit 6b5dcd46c5
3 changed files with 102 additions and 54 deletions

View File

@ -133,6 +133,7 @@ Optionally, you can use the following command-line flags:
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.| | `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
| `--no-stream` | Don't stream the text output in real time. This improves the text generation performance.| | `--no-stream` | Don't stream the text output in real time. This improves the text generation performance.|
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.| | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
| `--extensions EXTENSIONS` | The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this". |
| `--listen` | Make the web UI reachable from your local network.| | `--listen` | Make the web UI reachable from your local network.|
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. | | `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
| `--verbose` | Print the prompts to the terminal. | | `--verbose` | Print the prompts to the terminal. |

View File

@ -0,0 +1,14 @@
def input_modifier(string):
    """
    Applied to user text inputs before they are fed into the model.

    Demonstration implementation: substitutes every space with '#'.
    """
    return '#'.join(string.split(' '))
def output_modifier(string):
    """
    Applied to the model outputs before they are displayed.

    Demonstration implementation: substitutes every space with '_'.
    """
    return '_'.join(string.split(' '))

141
server.py
View File

@ -5,6 +5,7 @@ import glob
import torch import torch
import argparse import argparse
import json import json
import sys
from sys import exit from sys import exit
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
@ -32,6 +33,7 @@ parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to
parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.') parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
@ -165,6 +167,9 @@ def formatted_outputs(reply, model_name):
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None): def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
global model, tokenizer, model_name, loaded_preset, preset global model, tokenizer, model_name, loaded_preset, preset
original_question = question
if not (args.chat or args.cai_chat):
question = apply_extensions(question, "input")
if args.verbose: if args.verbose:
print(f"\n\n{question}\n--------------------\n") print(f"\n\n{question}\n--------------------\n")
@ -203,20 +208,36 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
reply = decode(output[0]) reply = decode(output[0])
t1 = time.time() t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)") print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
if not (args.chat or args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, model_name) yield formatted_outputs(reply, model_name)
# Generate the reply 1 token at a time # Generate the reply 1 token at a time
else: else:
yield formatted_outputs(question, model_name) yield formatted_outputs(original_question, model_name)
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8') preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
for i in tqdm(range(tokens//8+1)): for i in tqdm(range(tokens//8+1)):
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}") output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
reply = decode(output[0]) reply = decode(output[0])
if not (args.chat or args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, model_name) yield formatted_outputs(reply, model_name)
input_ids = output input_ids = output
if output[0][-1] == n: if output[0][-1] == n:
break break
def apply_extensions(text, typ):
    """
    Run every enabled extension's modifier hook over `text` and return
    the result.

    Extensions are applied in the order they were listed on the command
    line (the second item of each `extension_state` value).  `typ`
    selects the hook: "input" calls the extension's `input_modifier`,
    anything else calls its `output_modifier`.
    """
    import importlib

    global available_extensions, extension_state
    for ext in sorted(extension_state, key=lambda x: extension_state[x][1]):
        if extension_state[ext][0]:
            # import_module caches already-loaded modules, so repeated calls
            # are cheap; this also avoids building code strings for
            # exec()/eval(), which is fragile and a code-injection hazard.
            script = importlib.import_module(f"extensions.{ext}.script")
            if typ == "input":
                text = script.input_modifier(text)
            else:
                text = script.output_modifier(text)
    return text
def get_available_models(): def get_available_models():
return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower) return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
@ -226,9 +247,19 @@ def get_available_presets():
def get_available_characters(): def get_available_characters():
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower) return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
def get_available_extensions():
    """Return the names of extension folders that contain a script.py
    entry point, sorted case-insensitively."""
    names = {entry.parts[1] for entry in Path('extensions').glob('*/script.py')}
    return sorted(names, key=str.lower)
available_models = get_available_models() available_models = get_available_models()
available_presets = get_available_presets() available_presets = get_available_presets()
available_characters = get_available_characters() available_characters = get_available_characters()
available_extensions = get_available_extensions()
extension_state = {}
if args.extensions is not None:
for i,ext in enumerate(args.extensions.split(',')):
if ext in available_extensions:
print(f'The extension "{ext}" is enabled.')
extension_state[ext] = [True, i]
# Choosing the default model # Choosing the default model
if args.model is not None: if args.model is not None:
@ -256,7 +287,7 @@ description = f"\n\n# Text generation lab\nGenerate text using Large Language Mo
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}" css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
if args.chat or args.cai_chat: if args.chat or args.cai_chat:
history = [] history = {'internal': [], 'visible': []}
character = None character = None
# This gets the new line characters right. # This gets the new line characters right.
@ -270,13 +301,13 @@ if args.chat or args.cai_chat:
text = clean_chat_message(text) text = clean_chat_message(text)
rows = [f"{context.strip()}\n"] rows = [f"{context.strip()}\n"]
i = len(history)-1 i = len(history['internal'])-1
count = 0 count = 0
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens: while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
rows.insert(1, f"{name2}: {history[i][1].strip()}\n") rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
count += 1 count += 1
if not (history[i][0] == '<|BEGIN-VISIBLE-CHAT|>'): if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
rows.insert(1, f"{name1}: {history[i][0].strip()}\n") rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
count += 1 count += 1
i -= 1 i -= 1
if history_size != 0 and count >= history_size: if history_size != 0 and count >= history_size:
@ -291,18 +322,12 @@ if args.chat or args.cai_chat:
question = ''.join(rows) question = ''.join(rows)
return question return question
def remove_example_dialogue_from_history(history):
_history = copy.deepcopy(history)
for i in range(len(_history)):
if '<|BEGIN-VISIBLE-CHAT|>' in _history[i][0]:
_history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
_history = _history[i:]
break
return _history
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
original_text = text
text = apply_extensions(text, "input")
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size) question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
history.append(['', '']) history['internal'].append(['', ''])
history['visible'].append(['', ''])
eos_token = '\n' if check else None eos_token = '\n' if check else None
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"): for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
next_character_found = False next_character_found = False
@ -312,7 +337,6 @@ if args.chat or args.cai_chat:
idx = idx[len(previous_idx)-1] idx = idx[len(previous_idx)-1]
reply = reply[idx + len(f"\n{name2}:"):] reply = reply[idx + len(f"\n{name2}:"):]
if check: if check:
reply = reply.split('\n')[0].strip() reply = reply.split('\n')[0].strip()
else: else:
@ -322,7 +346,8 @@ if args.chat or args.cai_chat:
next_character_found = True next_character_found = True
reply = clean_chat_message(reply) reply = clean_chat_message(reply)
history[-1] = [text, reply] history['internal'][-1] = [text, reply]
history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
if next_character_found: if next_character_found:
break break
@ -335,16 +360,17 @@ if args.chat or args.cai_chat:
next_character_substring_found = True next_character_substring_found = True
if not next_character_substring_found: if not next_character_substring_found:
yield remove_example_dialogue_from_history(history) yield history['visible']
yield remove_example_dialogue_from_history(history) yield history['visible']
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
yield generate_chat_html(history, name1, name2, character) yield generate_chat_html(_history, name1, name2, character)
def regenerate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): def regenerate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
last = history.pop() last = history['internal'].pop()
history['visible'].pop()
text = last[0] text = last[0]
if args.cai_chat: if args.cai_chat:
for i in cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): for i in cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
@ -354,12 +380,15 @@ if args.chat or args.cai_chat:
yield i yield i
def remove_last_message(name1, name2): def remove_last_message(name1, name2):
last = history.pop() if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
_history = remove_example_dialogue_from_history(history) last = history['visible'].pop()
if args.cai_chat: history['internal'].pop()
return generate_chat_html(_history, name1, name2, character), last[0]
else: else:
return _history, last[0] last = ['', '']
if args.cai_chat:
return generate_chat_html(history['visible'], name1, name2, character), last[0]
else:
return history['visible'], last[0]
def clear_html(): def clear_html():
return generate_chat_html([], "", "", character) return generate_chat_html([], "", "", character)
@ -367,28 +396,31 @@ if args.chat or args.cai_chat:
def clear_chat_log(_character, name1, name2): def clear_chat_log(_character, name1, name2):
global history global history
if _character != 'None': if _character != 'None':
load_character(_character, name1, name2) for i in range(len(history['internal'])):
if '<|BEGIN-VISIBLE-CHAT|>' in history['internal'][i][0]:
history['visible'] = [['', history['internal'][i][1]]]
history['internal'] = history['internal'][:i+1]
break
else: else:
history = [] history['internal'] = []
_history = remove_example_dialogue_from_history(history) history['visible'] = []
if args.cai_chat: if args.cai_chat:
return generate_chat_html(_history, name1, name2, character) return generate_chat_html(history['visible'], name1, name2, character)
else: else:
return _history return history['visible']
def redraw_html(name1, name2): def redraw_html(name1, name2):
global history global history
_history = remove_example_dialogue_from_history(history) return generate_chat_html(history['visible'], name1, name2, character)
return generate_chat_html(_history, name1, name2, character)
def tokenize_dialogue(dialogue, name1, name2): def tokenize_dialogue(dialogue, name1, name2):
history = [] _history = []
dialogue = re.sub('<START>', '', dialogue) dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue) dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)] idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
if len(idx) == 0: if len(idx) == 0:
return history return _history
messages = [] messages = []
for i in range(len(idx)-1): for i in range(len(idx)-1):
@ -402,16 +434,16 @@ if args.chat or args.cai_chat:
elif i.startswith(f'{name2}:'): elif i.startswith(f'{name2}:'):
entry[1] = i[len(f'{name2}:'):].strip() entry[1] = i[len(f'{name2}:'):].strip()
if not (len(entry[0]) == 0 and len(entry[1]) == 0): if not (len(entry[0]) == 0 and len(entry[1]) == 0):
history.append(entry) _history.append(entry)
entry = ['', ''] entry = ['', '']
return history return _history
def save_history(): def save_history():
if not Path('logs').exists(): if not Path('logs').exists():
Path('logs').mkdir() Path('logs').mkdir()
with open(Path('logs/conversation.json'), 'w') as f: with open(Path('logs/conversation.json'), 'w') as f:
f.write(json.dumps({'data': history}, indent=2)) f.write(json.dumps({'data': history['internal']}, indent=2))
return Path('logs/conversation.json') return Path('logs/conversation.json')
def upload_history(file, name1, name2): def upload_history(file, name1, name2):
@ -420,21 +452,22 @@ if args.chat or args.cai_chat:
try: try:
j = json.loads(file) j = json.loads(file)
if 'data' in j: if 'data' in j:
history = j['data'] history['internal'] = j['data']
# Compatibility with Pygmalion AI's official web UI # Compatibility with Pygmalion AI's official web UI
elif 'chat' in j: elif 'chat' in j:
history = [':'.join(x.split(':')[1:]).strip() for x in j['chat']] history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'): if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
history = [['<|BEGIN-VISIBLE-CHAT|>', history[0]]] + [[history[i], history[i+1]] for i in range(1, len(history)-1, 2)] history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', history['internal'][0]]] + [[history['internal'][i], history['internal'][i+1]] for i in range(1, len(history['internal'])-1, 2)]
else: else:
history = [[history[i], history[i+1]] for i in range(0, len(history)-1, 2)] history['internal'] = [[history['internal'][i], history['internal'][i+1]] for i in range(0, len(history['internal'])-1, 2)]
except: except:
history = tokenize_dialogue(file, name1, name2) history['internal'] = tokenize_dialogue(file, name1, name2)
def load_character(_character, name1, name2): def load_character(_character, name1, name2):
global history, character global history, character
context = "" context = ""
history = [] history['internal'] = []
history['visible'] = []
if _character != 'None': if _character != 'None':
character = _character character = _character
with open(Path(f'characters/{_character}.json'), 'r') as f: with open(Path(f'characters/{_character}.json'), 'r') as f:
@ -446,24 +479,24 @@ if args.chat or args.cai_chat:
context += f"Scenario: {data['world_scenario']}\n" context += f"Scenario: {data['world_scenario']}\n"
context = f"{context.strip()}\n<START>\n" context = f"{context.strip()}\n<START>\n"
if 'example_dialogue' in data and data['example_dialogue'] != '': if 'example_dialogue' in data and data['example_dialogue'] != '':
history = tokenize_dialogue(data['example_dialogue'], name1, name2) history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0: if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]] history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
history['visible'] += [['', data['char_greeting']]]
else: else:
history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]] history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
history['visible'] += [['', "Hello there!"]]
else: else:
character = None character = None
context = settings['context_pygmalion'] context = settings['context_pygmalion']
name2 = settings['name2_pygmalion'] name2 = settings['name2_pygmalion']
_history = remove_example_dialogue_from_history(history)
if args.cai_chat: if args.cai_chat:
return name2, context, generate_chat_html(_history, name1, name2, character) return name2, context, generate_chat_html(history['visible'], name1, name2, character)
else: else:
return name2, context, _history return name2, context, history['visible']
def upload_character(file, name1, name2): def upload_character(file, name1, name2):
global history
file = file.decode('utf-8') file = file.decode('utf-8')
data = json.loads(file) data = json.loads(file)
outfile_name = data["char_name"] outfile_name = data["char_name"]
@ -543,7 +576,7 @@ if args.chat or args.cai_chat:
if args.cai_chat: if args.cai_chat:
upload.upload(redraw_html, [name1, name2], [display1]) upload.upload(redraw_html, [name1, name2], [display1])
else: else:
upload.upload(lambda : remove_example_dialogue_from_history(history), [], [display1]) upload.upload(lambda : history['visible'], [], [display1])
elif args.notebook: elif args.notebook:
with gr.Blocks(css=css, analytics_enabled=False) as interface: with gr.Blocks(css=css, analytics_enabled=False) as interface: