diff --git a/README.md b/README.md
index 1fb8227d..041a6b04 100644
--- a/README.md
+++ b/README.md
@@ -133,6 +133,7 @@ Optionally, you can use the following command-line flags:
 | `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
 | `--no-stream` | Don't stream the text output in real time. This improves the text generation performance.|
 | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
+| `--extensions EXTENSIONS` | The list of extensions to load. To load more than one extension, write the names separated by commas and enclosed in quotation marks, "like,this". |
 | `--listen` | Make the web UI reachable from your local network.|
 | `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
 | `--verbose` | Print the prompts to the terminal. |
diff --git a/extensions/example/script.py b/extensions/example/script.py
new file mode 100644
index 00000000..4314af2d
--- /dev/null
+++ b/extensions/example/script.py
@@ -0,0 +1,14 @@
+def input_modifier(string):
+    """
+    This function is applied to your text inputs before
+    they are fed into the model.
+    """
+
+    return string.replace(' ', '#')
+
+def output_modifier(string):
+    """
+    This function is applied to the model outputs.
+    """
+
+    return string.replace(' ', '_')
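These two functions are the entire contract an extension has to satisfy; the `#`/`_` substitutions above exist only to make the effect visible. A slightly more practical sketch (hypothetical, not part of this patch) follows the same shape:

# extensions/profanity_filter/script.py -- hypothetical illustration only.

def input_modifier(string):
    # Applied to the prompt text before it reaches the model.
    return string.strip()

def output_modifier(string):
    # Applied to the decoded model output before it is displayed.
    return string.replace('badword', '*' * 7)

Because the loader in server.py globs `extensions/*/script.py` and imports `extensions.<name>.script`, the file must be named `script.py` and sit in its own subdirectory under `extensions/`.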
diff --git a/server.py b/server.py
index 0d1483db..b51425db 100644
--- a/server.py
+++ b/server.py
@@ -5,6 +5,7 @@ import glob
 import torch
 import argparse
 import json
+import sys
 from sys import exit
 from pathlib import Path
 import gradio as gr
@@ -32,6 +33,7 @@ parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to
 parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
+parser.add_argument('--extensions', type=str, help='The list of extensions to load. To load more than one extension, write the names separated by commas and enclosed in quotation marks, "like,this".')
 parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
 parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
@@ -165,6 +167,9 @@ def formatted_outputs(reply, model_name):
 def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
+    original_question = question
+    if not (args.chat or args.cai_chat):
+        question = apply_extensions(question, "input")
     if args.verbose:
         print(f"\n\n{question}\n--------------------\n")
 
@@ -203,20 +208,36 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         reply = decode(output[0])
         t1 = time.time()
         print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
+        if not (args.chat or args.cai_chat):
+            reply = original_question + apply_extensions(reply[len(question):], "output")
         yield formatted_outputs(reply, model_name)
 
     # Generate the reply 1 token at a time
     else:
-        yield formatted_outputs(question, model_name)
+        yield formatted_outputs(original_question, model_name)
         preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
         for i in tqdm(range(tokens//8+1)):
             output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
             reply = decode(output[0])
+            if not (args.chat or args.cai_chat):
+                reply = original_question + apply_extensions(reply[len(question):], "output")
             yield formatted_outputs(reply, model_name)
             input_ids = output
             if output[0][-1] == n:
                 break
 
+def apply_extensions(text, typ):
+    global available_extensions, extension_state
+    for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
+        if extension_state[ext][0] == True:
+            ext_string = f"extensions.{ext}.script"
+            exec(f"import {ext_string}")
+            if typ == "input":
+                text = eval(f"{ext_string}.input_modifier(text)")
+            else:
+                text = eval(f"{ext_string}.output_modifier(text)")
+    return text
+
 def get_available_models():
     return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
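`apply_extensions` builds module paths as strings and runs them through `exec`/`eval`. For readers who find that hard to audit, here is a behavior-equivalent sketch using `importlib`; it is an alternative formulation, not what the patch ships:

import importlib

def apply_extensions(text, typ):
    # Walk the enabled extensions in command-line order (the second element
    # of each extension_state entry is the position index) and pipe the text
    # through the matching modifier.
    for ext in sorted(extension_state, key=lambda x: extension_state[x][1]):
        if extension_state[ext][0]:
            module = importlib.import_module(f"extensions.{ext}.script")
            if typ == "input":
                text = module.input_modifier(text)
            else:
                text = module.output_modifier(text)
    return text

Either way, repeated calls are cheap: after the first import, Python serves the module from the `sys.modules` cache.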
@@ -226,9 +247,19 @@ def get_available_presets():
 def get_available_characters():
     return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
 
+def get_available_extensions():
+    return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
+
 available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
+available_extensions = get_available_extensions()
+extension_state = {}
+if args.extensions is not None:
+    for i,ext in enumerate(args.extensions.split(',')):
+        if ext in available_extensions:
+            print(f'The extension "{ext}" is enabled.')
+            extension_state[ext] = [True, i]
 
 # Choosing the default model
 if args.model is not None:
@@ -256,7 +287,7 @@ description = f"\n\n# Text generation lab\nGenerate text using Large Language Mo
 css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
 
 if args.chat or args.cai_chat:
-    history = []
+    history = {'internal': [], 'visible': []}
     character = None
 
     # This gets the new line characters right.
@@ -270,13 +301,13 @@ if args.chat or args.cai_chat:
         text = clean_chat_message(text)
         rows = [f"{context.strip()}\n"]
-        i = len(history)-1
+        i = len(history['internal'])-1
         count = 0
         while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
-            rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
+            rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
             count += 1
-            if not (history[i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
-                rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
+            if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
+                rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
                 count += 1
             i -= 1
         if history_size != 0 and count >= history_size:
@@ -291,18 +322,12 @@ if args.chat or args.cai_chat:
         question = ''.join(rows)
         return question
 
-    def remove_example_dialogue_from_history(history):
-        _history = copy.deepcopy(history)
-        for i in range(len(_history)):
-            if '<|BEGIN-VISIBLE-CHAT|>' in _history[i][0]:
-                _history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
-                _history = _history[i:]
-                break
-        return _history
-
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
+        original_text = text
+        text = apply_extensions(text, "input")
         question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
-        history.append(['', ''])
+        history['internal'].append(['', ''])
+        history['visible'].append(['', ''])
         eos_token = '\n' if check else None
         for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
             next_character_found = False
@@ -312,7 +337,6 @@ if args.chat or args.cai_chat:
                 idx = idx[len(previous_idx)-1]
 
             reply = reply[idx + len(f"\n{name2}:"):]
-
             if check:
                 reply = reply.split('\n')[0].strip()
             else:
@@ -322,7 +346,8 @@ if args.chat or args.cai_chat:
                    next_character_found = True
             reply = clean_chat_message(reply)
 
-            history[-1] = [text, reply]
+            history['internal'][-1] = [text, reply]
+            history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
 
             if next_character_found:
                 break
@@ -335,16 +360,17 @@ if args.chat or args.cai_chat:
                    next_character_substring_found = True
 
             if not next_character_substring_found:
-                yield remove_example_dialogue_from_history(history)
+                yield history['visible']
 
-        yield remove_example_dialogue_from_history(history)
+        yield history['visible']
 
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
-        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
-            yield generate_chat_html(history, name1, name2, character)
+        for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
+            yield generate_chat_html(_history, name1, name2, character)
 
     def regenerate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
-        last = history.pop()
+        last = history['internal'].pop()
+        history['visible'].pop()
         text = last[0]
         if args.cai_chat:
             for i in cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
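The two parallel logs are the heart of this change: `history['internal']` records the conversation as the model saw it (extension-modified input, raw replies, and the `<|BEGIN-VISIBLE-CHAT|>` marker rows), while `history['visible']` records what the UI renders (the user's original text and output-modified replies). This is also why `remove_example_dialogue_from_history` can be deleted: the visible log never contains the marker in the first place. With the example extension enabled, a greeting plus one exchange would leave the logs in a state like this (illustrative values):

history = {
    # As seen by the model: greeting marker plus the '#'-mangled user input.
    'internal': [['<|BEGIN-VISIBLE-CHAT|>', 'Hello there!'],
                 ['How#are#you?', 'Doing well, thanks.']],
    # As rendered in the UI: marker row blanked, reply passed through output_modifier.
    'visible': [['', 'Hello there!'],
                ['How are you?', 'Doing_well,_thanks.']],
}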
@@ -354,12 +380,15 @@ if args.chat or args.cai_chat:
                yield i
 
     def remove_last_message(name1, name2):
-        last = history.pop()
-        _history = remove_example_dialogue_from_history(history)
-        if args.cai_chat:
-            return generate_chat_html(_history, name1, name2, character), last[0]
+        if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
+            last = history['visible'].pop()
+            history['internal'].pop()
         else:
-            return _history, last[0]
+            last = ['', '']
+        if args.cai_chat:
+            return generate_chat_html(history['visible'], name1, name2, character), last[0]
+        else:
+            return history['visible'], last[0]
 
     def clear_html():
         return generate_chat_html([], "", "", character)
@@ -367,28 +396,31 @@ if args.chat or args.cai_chat:
     def clear_chat_log(_character, name1, name2):
         global history
         if _character != 'None':
-            load_character(_character, name1, name2)
+            for i in range(len(history['internal'])):
+                if '<|BEGIN-VISIBLE-CHAT|>' in history['internal'][i][0]:
+                    history['visible'] = [['', history['internal'][i][1]]]
+                    history['internal'] = history['internal'][:i+1]
+                    break
         else:
-            history = []
-        _history = remove_example_dialogue_from_history(history)
+            history['internal'] = []
+            history['visible'] = []
         if args.cai_chat:
-            return generate_chat_html(_history, name1, name2, character)
+            return generate_chat_html(history['visible'], name1, name2, character)
         else:
-            return _history
+            return history['visible']
 
     def redraw_html(name1, name2):
         global history
-        _history = remove_example_dialogue_from_history(history)
-        return generate_chat_html(_history, name1, name2, character)
+        return generate_chat_html(history['visible'], name1, name2, character)
 
     def tokenize_dialogue(dialogue, name1, name2):
-        history = []
+        _history = []
         dialogue = re.sub('<START>', '', dialogue)
         dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
         idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
         if len(idx) == 0:
-            return history
+            return _history
 
         messages = []
         for i in range(len(idx)-1):
@@ -402,16 +434,16 @@ if args.chat or args.cai_chat:
             elif i.startswith(f'{name2}:'):
                 entry[1] = i[len(f'{name2}:'):].strip()
                 if not (len(entry[0]) == 0 and len(entry[1]) == 0):
-                    history.append(entry)
+                    _history.append(entry)
                 entry = ['', '']
-        return history
+        return _history
 
     def save_history():
         if not Path('logs').exists():
             Path('logs').mkdir()
         with open(Path('logs/conversation.json'), 'w') as f:
-            f.write(json.dumps({'data': history}, indent=2))
+            f.write(json.dumps({'data': history['internal']}, indent=2))
         return Path('logs/conversation.json')
 
     def upload_history(file, name1, name2):
@@ -420,21 +452,22 @@
         try:
             j = json.loads(file)
             if 'data' in j:
-                history = j['data']
+                history['internal'] = j['data']
             # Compatibility with Pygmalion AI's official web UI
             elif 'chat' in j:
-                history = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
+                history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
                 if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
-                    history = [['<|BEGIN-VISIBLE-CHAT|>', history[0]]] + [[history[i], history[i+1]] for i in range(1, len(history)-1, 2)]
+                    history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', history['internal'][0]]] + [[history['internal'][i], history['internal'][i+1]] for i in range(1, len(history['internal'])-1, 2)]
                 else:
-                    history = [[history[i], history[i+1]] for i in range(0, len(history)-1, 2)]
+                    history['internal'] = [[history['internal'][i], history['internal'][i+1]] for i in range(0, len(history['internal'])-1, 2)]
         except:
-            history = tokenize_dialogue(file, name1, name2)
+            history['internal'] = tokenize_dialogue(file, name1, name2)
 
     def load_character(_character, name1, name2):
         global history, character
         context = ""
-        history = []
+        history['internal'] = []
+        history['visible'] = []
         if _character != 'None':
             character = _character
             with open(Path(f'characters/{_character}.json'), 'r') as f:
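The Pygmalion branch of `upload_history` is the densest part: it strips the speaker prefixes from `j['chat']`, then pairs consecutive messages into `[user, bot]` rows, prepending a `<|BEGIN-VISIBLE-CHAT|>` row when the log opens with a bot greeting. The same pairing logic in isolation, on hypothetical sample data:

# Hypothetical Pygmalion-style log in which the bot speaks first.
chat = ['Bot: Hi!', 'You: How are you?', 'Bot: Great.']
messages = [':'.join(x.split(':')[1:]).strip() for x in chat]
# messages == ['Hi!', 'How are you?', 'Great.']

# The greeting becomes a hidden marker row; the rest pair up as [user, bot].
internal = [['<|BEGIN-VISIBLE-CHAT|>', messages[0]]] + \
           [[messages[i], messages[i+1]] for i in range(1, len(messages)-1, 2)]
# internal == [['<|BEGIN-VISIBLE-CHAT|>', 'Hi!'], ['How are you?', 'Great.']]

Note that the upload paths only rebuild `history['internal']`; the display is re-rendered from `history['visible']` by the `upload.upload(...)` handlers at the bottom of the diff.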
@@ -446,24 +479,24 @@ if args.chat or args.cai_chat:
                context += f"Scenario: {data['world_scenario']}\n"
             context = f"{context.strip()}\n\n"
             if 'example_dialogue' in data and data['example_dialogue'] != '':
-                history = tokenize_dialogue(data['example_dialogue'], name1, name2)
+                history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
             if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
-                history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
+                history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
+                history['visible'] += [['', data['char_greeting']]]
             else:
-                history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
+                history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
+                history['visible'] += [['', "Hello there!"]]
         else:
             character = None
             context = settings['context_pygmalion']
             name2 = settings['name2_pygmalion']
 
-        _history = remove_example_dialogue_from_history(history)
         if args.cai_chat:
-            return name2, context, generate_chat_html(_history, name1, name2, character)
+            return name2, context, generate_chat_html(history['visible'], name1, name2, character)
         else:
-            return name2, context, _history
+            return name2, context, history['visible']
 
     def upload_character(file, name1, name2):
-        global history
         file = file.decode('utf-8')
         data = json.loads(file)
         outfile_name = data["char_name"]
@@ -543,7 +576,7 @@ if args.chat or args.cai_chat:
         if args.cai_chat:
             upload.upload(redraw_html, [name1, name2], [display1])
         else:
-            upload.upload(lambda : remove_example_dialogue_from_history(history), [], [display1])
+            upload.upload(lambda : history['visible'], [], [display1])
 
 elif args.notebook:
     with gr.Blocks(css=css, analytics_enabled=False) as interface:
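With the patch applied, extensions are enabled at launch, e.g. `python server.py --extensions example`, or several at once as `--extensions "name1,name2"`. They run in the order listed, since `extension_state[ext] = [True, i]` stores the position that `apply_extensions` later sorts by. A new extension can also be smoke-tested without starting the server; a minimal sketch, assuming it is run from the repository root:

# Quick check that the bundled example extension behaves as documented.
from extensions.example.script import input_modifier, output_modifier

assert input_modifier('hello world') == 'hello#world'
assert output_modifier('hello world') == 'hello_world'
print('example extension OK')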