mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-24 18:49:23 +01:00
Add support for extensions
This is experimental.
This commit is contained in:
parent
414fa9d161
commit
6b5dcd46c5
@ -133,6 +133,7 @@ Optionally, you can use the following command-line flags:
|
||||
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
|
||||
| `--no-stream` | Don't stream the text output in real time. This improves the text generation performance.|
|
||||
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
|
||||
| `--extensions EXTENSIONS` | The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this". |
|
||||
| `--listen` | Make the web UI reachable from your local network.|
|
||||
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
|
||||
| `--verbose` | Print the prompts to the terminal. |
|
||||
|
14
extensions/example/script.py
Normal file
14
extensions/example/script.py
Normal file
@ -0,0 +1,14 @@
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
|
||||
return string.replace(' ', '#')
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
"""
|
||||
|
||||
return string.replace(' ', '_')
|
141
server.py
141
server.py
@ -5,6 +5,7 @@ import glob
|
||||
import torch
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from sys import exit
|
||||
from pathlib import Path
|
||||
import gradio as gr
|
||||
@ -32,6 +33,7 @@ parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to
|
||||
parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
|
||||
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
|
||||
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
|
||||
parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".')
|
||||
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
|
||||
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
|
||||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||
@ -165,6 +167,9 @@ def formatted_outputs(reply, model_name):
|
||||
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
|
||||
global model, tokenizer, model_name, loaded_preset, preset
|
||||
|
||||
original_question = question
|
||||
if not (args.chat or args.cai_chat):
|
||||
question = apply_extensions(question, "input")
|
||||
if args.verbose:
|
||||
print(f"\n\n{question}\n--------------------\n")
|
||||
|
||||
@ -203,20 +208,36 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||
reply = decode(output[0])
|
||||
t1 = time.time()
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
|
||||
if not (args.chat or args.cai_chat):
|
||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||
yield formatted_outputs(reply, model_name)
|
||||
|
||||
# Generate the reply 1 token at a time
|
||||
else:
|
||||
yield formatted_outputs(question, model_name)
|
||||
yield formatted_outputs(original_question, model_name)
|
||||
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
|
||||
for i in tqdm(range(tokens//8+1)):
|
||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
||||
reply = decode(output[0])
|
||||
if not (args.chat or args.cai_chat):
|
||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||
yield formatted_outputs(reply, model_name)
|
||||
input_ids = output
|
||||
if output[0][-1] == n:
|
||||
break
|
||||
|
||||
def apply_extensions(text, typ):
|
||||
global available_extensions, extension_state
|
||||
for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
|
||||
if extension_state[ext][0] == True:
|
||||
ext_string = f"extensions.{ext}.script"
|
||||
exec(f"import {ext_string}")
|
||||
if typ == "input":
|
||||
text = eval(f"{ext_string}.input_modifier(text)")
|
||||
else:
|
||||
text = eval(f"{ext_string}.output_modifier(text)")
|
||||
return text
|
||||
|
||||
def get_available_models():
|
||||
return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
|
||||
|
||||
@ -226,9 +247,19 @@ def get_available_presets():
|
||||
def get_available_characters():
|
||||
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
|
||||
|
||||
def get_available_extensions():
|
||||
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
||||
|
||||
available_models = get_available_models()
|
||||
available_presets = get_available_presets()
|
||||
available_characters = get_available_characters()
|
||||
available_extensions = get_available_extensions()
|
||||
extension_state = {}
|
||||
if args.extensions is not None:
|
||||
for i,ext in enumerate(args.extensions.split(',')):
|
||||
if ext in available_extensions:
|
||||
print(f'The extension "{ext}" is enabled.')
|
||||
extension_state[ext] = [True, i]
|
||||
|
||||
# Choosing the default model
|
||||
if args.model is not None:
|
||||
@ -256,7 +287,7 @@ description = f"\n\n# Text generation lab\nGenerate text using Large Language Mo
|
||||
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
|
||||
|
||||
if args.chat or args.cai_chat:
|
||||
history = []
|
||||
history = {'internal': [], 'visible': []}
|
||||
character = None
|
||||
|
||||
# This gets the new line characters right.
|
||||
@ -270,13 +301,13 @@ if args.chat or args.cai_chat:
|
||||
text = clean_chat_message(text)
|
||||
|
||||
rows = [f"{context.strip()}\n"]
|
||||
i = len(history)-1
|
||||
i = len(history['internal'])-1
|
||||
count = 0
|
||||
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
|
||||
rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
|
||||
rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
|
||||
count += 1
|
||||
if not (history[i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
|
||||
rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
|
||||
if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
|
||||
rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
|
||||
count += 1
|
||||
i -= 1
|
||||
if history_size != 0 and count >= history_size:
|
||||
@ -291,18 +322,12 @@ if args.chat or args.cai_chat:
|
||||
question = ''.join(rows)
|
||||
return question
|
||||
|
||||
def remove_example_dialogue_from_history(history):
|
||||
_history = copy.deepcopy(history)
|
||||
for i in range(len(_history)):
|
||||
if '<|BEGIN-VISIBLE-CHAT|>' in _history[i][0]:
|
||||
_history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
|
||||
_history = _history[i:]
|
||||
break
|
||||
return _history
|
||||
|
||||
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
||||
original_text = text
|
||||
text = apply_extensions(text, "input")
|
||||
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
|
||||
history.append(['', ''])
|
||||
history['internal'].append(['', ''])
|
||||
history['visible'].append(['', ''])
|
||||
eos_token = '\n' if check else None
|
||||
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
||||
next_character_found = False
|
||||
@ -312,7 +337,6 @@ if args.chat or args.cai_chat:
|
||||
idx = idx[len(previous_idx)-1]
|
||||
|
||||
reply = reply[idx + len(f"\n{name2}:"):]
|
||||
|
||||
if check:
|
||||
reply = reply.split('\n')[0].strip()
|
||||
else:
|
||||
@ -322,7 +346,8 @@ if args.chat or args.cai_chat:
|
||||
next_character_found = True
|
||||
reply = clean_chat_message(reply)
|
||||
|
||||
history[-1] = [text, reply]
|
||||
history['internal'][-1] = [text, reply]
|
||||
history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
|
||||
if next_character_found:
|
||||
break
|
||||
|
||||
@ -335,16 +360,17 @@ if args.chat or args.cai_chat:
|
||||
next_character_substring_found = True
|
||||
|
||||
if not next_character_substring_found:
|
||||
yield remove_example_dialogue_from_history(history)
|
||||
yield history['visible']
|
||||
|
||||
yield remove_example_dialogue_from_history(history)
|
||||
yield history['visible']
|
||||
|
||||
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
||||
for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
||||
yield generate_chat_html(history, name1, name2, character)
|
||||
for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
||||
yield generate_chat_html(_history, name1, name2, character)
|
||||
|
||||
def regenerate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
||||
last = history.pop()
|
||||
last = history['internal'].pop()
|
||||
history['visible'].pop()
|
||||
text = last[0]
|
||||
if args.cai_chat:
|
||||
for i in cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
||||
@ -354,12 +380,15 @@ if args.chat or args.cai_chat:
|
||||
yield i
|
||||
|
||||
def remove_last_message(name1, name2):
|
||||
last = history.pop()
|
||||
_history = remove_example_dialogue_from_history(history)
|
||||
if args.cai_chat:
|
||||
return generate_chat_html(_history, name1, name2, character), last[0]
|
||||
if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
|
||||
last = history['visible'].pop()
|
||||
history['internal'].pop()
|
||||
else:
|
||||
return _history, last[0]
|
||||
last = ['', '']
|
||||
if args.cai_chat:
|
||||
return generate_chat_html(history['visible'], name1, name2, character), last[0]
|
||||
else:
|
||||
return history['visible'], last[0]
|
||||
|
||||
def clear_html():
|
||||
return generate_chat_html([], "", "", character)
|
||||
@ -367,28 +396,31 @@ if args.chat or args.cai_chat:
|
||||
def clear_chat_log(_character, name1, name2):
|
||||
global history
|
||||
if _character != 'None':
|
||||
load_character(_character, name1, name2)
|
||||
for i in range(len(history['internal'])):
|
||||
if '<|BEGIN-VISIBLE-CHAT|>' in history['internal'][i][0]:
|
||||
history['visible'] = [['', history['internal'][i][1]]]
|
||||
history['internal'] = history['internal'][:i+1]
|
||||
break
|
||||
else:
|
||||
history = []
|
||||
_history = remove_example_dialogue_from_history(history)
|
||||
history['internal'] = []
|
||||
history['visible'] = []
|
||||
if args.cai_chat:
|
||||
return generate_chat_html(_history, name1, name2, character)
|
||||
return generate_chat_html(history['visible'], name1, name2, character)
|
||||
else:
|
||||
return _history
|
||||
return history['visible']
|
||||
|
||||
def redraw_html(name1, name2):
|
||||
global history
|
||||
_history = remove_example_dialogue_from_history(history)
|
||||
return generate_chat_html(_history, name1, name2, character)
|
||||
return generate_chat_html(history['visible'], name1, name2, character)
|
||||
|
||||
def tokenize_dialogue(dialogue, name1, name2):
|
||||
history = []
|
||||
_history = []
|
||||
|
||||
dialogue = re.sub('<START>', '', dialogue)
|
||||
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
||||
idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
|
||||
if len(idx) == 0:
|
||||
return history
|
||||
return _history
|
||||
|
||||
messages = []
|
||||
for i in range(len(idx)-1):
|
||||
@ -402,16 +434,16 @@ if args.chat or args.cai_chat:
|
||||
elif i.startswith(f'{name2}:'):
|
||||
entry[1] = i[len(f'{name2}:'):].strip()
|
||||
if not (len(entry[0]) == 0 and len(entry[1]) == 0):
|
||||
history.append(entry)
|
||||
_history.append(entry)
|
||||
entry = ['', '']
|
||||
|
||||
return history
|
||||
return _history
|
||||
|
||||
def save_history():
|
||||
if not Path('logs').exists():
|
||||
Path('logs').mkdir()
|
||||
with open(Path('logs/conversation.json'), 'w') as f:
|
||||
f.write(json.dumps({'data': history}, indent=2))
|
||||
f.write(json.dumps({'data': history['internal']}, indent=2))
|
||||
return Path('logs/conversation.json')
|
||||
|
||||
def upload_history(file, name1, name2):
|
||||
@ -420,21 +452,22 @@ if args.chat or args.cai_chat:
|
||||
try:
|
||||
j = json.loads(file)
|
||||
if 'data' in j:
|
||||
history = j['data']
|
||||
history['internal'] = j['data']
|
||||
# Compatibility with Pygmalion AI's official web UI
|
||||
elif 'chat' in j:
|
||||
history = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
|
||||
history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
|
||||
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
|
||||
history = [['<|BEGIN-VISIBLE-CHAT|>', history[0]]] + [[history[i], history[i+1]] for i in range(1, len(history)-1, 2)]
|
||||
history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', history['internal'][0]]] + [[history['internal'][i], history['internal'][i+1]] for i in range(1, len(history['internal'])-1, 2)]
|
||||
else:
|
||||
history = [[history[i], history[i+1]] for i in range(0, len(history)-1, 2)]
|
||||
history['internal'] = [[history['internal'][i], history['internal'][i+1]] for i in range(0, len(history['internal'])-1, 2)]
|
||||
except:
|
||||
history = tokenize_dialogue(file, name1, name2)
|
||||
history['internal'] = tokenize_dialogue(file, name1, name2)
|
||||
|
||||
def load_character(_character, name1, name2):
|
||||
global history, character
|
||||
context = ""
|
||||
history = []
|
||||
history['internal'] = []
|
||||
history['visible'] = []
|
||||
if _character != 'None':
|
||||
character = _character
|
||||
with open(Path(f'characters/{_character}.json'), 'r') as f:
|
||||
@ -446,24 +479,24 @@ if args.chat or args.cai_chat:
|
||||
context += f"Scenario: {data['world_scenario']}\n"
|
||||
context = f"{context.strip()}\n<START>\n"
|
||||
if 'example_dialogue' in data and data['example_dialogue'] != '':
|
||||
history = tokenize_dialogue(data['example_dialogue'], name1, name2)
|
||||
history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
|
||||
if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
|
||||
history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
|
||||
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
|
||||
history['visible'] += [['', data['char_greeting']]]
|
||||
else:
|
||||
history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
|
||||
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
|
||||
history['visible'] += [['', "Hello there!"]]
|
||||
else:
|
||||
character = None
|
||||
context = settings['context_pygmalion']
|
||||
name2 = settings['name2_pygmalion']
|
||||
|
||||
_history = remove_example_dialogue_from_history(history)
|
||||
if args.cai_chat:
|
||||
return name2, context, generate_chat_html(_history, name1, name2, character)
|
||||
return name2, context, generate_chat_html(history['visible'], name1, name2, character)
|
||||
else:
|
||||
return name2, context, _history
|
||||
return name2, context, history['visible']
|
||||
|
||||
def upload_character(file, name1, name2):
|
||||
global history
|
||||
file = file.decode('utf-8')
|
||||
data = json.loads(file)
|
||||
outfile_name = data["char_name"]
|
||||
@ -543,7 +576,7 @@ if args.chat or args.cai_chat:
|
||||
if args.cai_chat:
|
||||
upload.upload(redraw_html, [name1, name2], [display1])
|
||||
else:
|
||||
upload.upload(lambda : remove_example_dialogue_from_history(history), [], [display1])
|
||||
upload.upload(lambda : history['visible'], [], [display1])
|
||||
|
||||
elif args.notebook:
|
||||
with gr.Blocks(css=css, analytics_enabled=False) as interface:
|
||||
|
Loading…
Reference in New Issue
Block a user