From 3687962e6c4d4ce1c6f3c43f6e31cba02353859d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 28 Jan 2023 20:18:23 -0300 Subject: [PATCH] Add support for TavernAI character cards (closes #31) --- server.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/server.py b/server.py index 6050f740..dbc78d6e 100644 --- a/server.py +++ b/server.py @@ -6,6 +6,7 @@ import torch import argparse import json import io +import base64 import sys from sys import exit from pathlib import Path @@ -422,7 +423,9 @@ if args.chat or args.cai_chat: _history = [] dialogue = re.sub('', '', dialogue) + dialogue = re.sub('', '', dialogue) dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue) + dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\1{name2}:', dialogue) idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)] if len(idx) == 0: return _history @@ -442,6 +445,11 @@ if args.chat or args.cai_chat: _history.append(entry) entry = ['', ''] + print(f"\nDialogue tokenized to:\n\n", end='') + for i in _history: + print(i) + print("--------------------\n", end='') + return _history def save_history(): @@ -506,8 +514,8 @@ if args.chat or args.cai_chat: else: return name2, context, history['visible'] - def upload_character(json_file, img, name1, name2): - json_file = json_file.decode('utf-8') + def upload_character(json_file, img): + json_file = json_file if type(json_file) == str else json_file.decode('utf-8') data = json.loads(json_file) outfile_name = data["char_name"] i = 1 @@ -517,15 +525,24 @@ if args.chat or args.cai_chat: with open(Path(f'characters/{outfile_name}.json'), 'w') as f: f.write(json_file) if img is not None: - img = Image.open(io.BytesIO(img)).convert('RGB') - img.save(Path(f'characters/{outfile_name}.jpg')) + img = Image.open(io.BytesIO(img)) + img.save(Path(f'characters/{outfile_name}.png')) print(f'New character saved to "characters/{outfile_name}.json".') return outfile_name + def upload_tavern_character(img, name1, name2): + _img = Image.open(io.BytesIO(img)) + _img.getexif() + decoded_string = base64.b64decode(_img.info['chara']) + _json = json.loads(decoded_string) + _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} + _json['example_dialogue'] = _json['example_dialogue'].replace('{{user}}', name1).replace('{{char}}', _json['char_name']) + return upload_character(json.dumps(_json), img) + def upload_your_profile_picture(img): - img = Image.open(io.BytesIO(img)).convert('RGB') - img.save(Path(f'img_me.jpg')) - print(f'Profile picture saved to "img_me.jpg"') + img = Image.open(io.BytesIO(img)) + img.save(Path(f'img_me.png')) + print(f'Profile picture saved to "img_me.png"') suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else '' with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface: @@ -579,6 +596,8 @@ if args.chat or args.cai_chat: upload_btn = gr.Button(value="Submit") with gr.Tab('Upload your profile picture'): upload_img_me = gr.File(type='binary') + with gr.Tab('Upload TavernAI Character Card'): + upload_img_tavern = gr.File(type='binary') input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider] if args.cai_chat: @@ -598,7 +617,8 @@ if args.chat or args.cai_chat: save_btn.click(save_history, inputs=[], outputs=[download]) character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1]) upload.upload(upload_history, [upload, name1, name2], []) - upload_btn.click(upload_character, [upload_char, upload_img, name1, name2], [character_menu]) + upload_btn.click(upload_character, [upload_char, upload_img], [character_menu]) + upload_img_tavern.upload(upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu]) upload_img_me.upload(upload_your_profile_picture, [upload_img_me], []) if args.cai_chat: