Add support for TavernAI character cards (closes #31)

This commit is contained in:
oobabooga 2023-01-28 20:18:23 -03:00
parent f71531186b
commit 3687962e6c

View File

@ -6,6 +6,7 @@ import torch
import argparse import argparse
import json import json
import io import io
import base64
import sys import sys
from sys import exit from sys import exit
from pathlib import Path from pathlib import Path
@ -422,7 +423,9 @@ if args.chat or args.cai_chat:
_history = [] _history = []
dialogue = re.sub('<START>', '', dialogue) dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('<start>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue) dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\1{name2}:', dialogue)
idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)] idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
if len(idx) == 0: if len(idx) == 0:
return _history return _history
@ -442,6 +445,11 @@ if args.chat or args.cai_chat:
_history.append(entry) _history.append(entry)
entry = ['', ''] entry = ['', '']
print(f"\nDialogue tokenized to:\n\n", end='')
for i in _history:
print(i)
print("--------------------\n", end='')
return _history return _history
def save_history(): def save_history():
@ -506,8 +514,8 @@ if args.chat or args.cai_chat:
else: else:
return name2, context, history['visible'] return name2, context, history['visible']
def upload_character(json_file, img, name1, name2): def upload_character(json_file, img):
json_file = json_file.decode('utf-8') json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
data = json.loads(json_file) data = json.loads(json_file)
outfile_name = data["char_name"] outfile_name = data["char_name"]
i = 1 i = 1
@ -517,15 +525,24 @@ if args.chat or args.cai_chat:
with open(Path(f'characters/{outfile_name}.json'), 'w') as f: with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
f.write(json_file) f.write(json_file)
if img is not None: if img is not None:
img = Image.open(io.BytesIO(img)).convert('RGB') img = Image.open(io.BytesIO(img))
img.save(Path(f'characters/{outfile_name}.jpg')) img.save(Path(f'characters/{outfile_name}.png'))
print(f'New character saved to "characters/{outfile_name}.json".') print(f'New character saved to "characters/{outfile_name}.json".')
return outfile_name return outfile_name
def upload_tavern_character(img, name1, name2):
_img = Image.open(io.BytesIO(img))
_img.getexif()
decoded_string = base64.b64decode(_img.info['chara'])
_json = json.loads(decoded_string)
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
_json['example_dialogue'] = _json['example_dialogue'].replace('{{user}}', name1).replace('{{char}}', _json['char_name'])
return upload_character(json.dumps(_json), img)
def upload_your_profile_picture(img): def upload_your_profile_picture(img):
img = Image.open(io.BytesIO(img)).convert('RGB') img = Image.open(io.BytesIO(img))
img.save(Path(f'img_me.jpg')) img.save(Path(f'img_me.png'))
print(f'Profile picture saved to "img_me.jpg"') print(f'Profile picture saved to "img_me.png"')
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else '' suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface: with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface:
@ -579,6 +596,8 @@ if args.chat or args.cai_chat:
upload_btn = gr.Button(value="Submit") upload_btn = gr.Button(value="Submit")
with gr.Tab('Upload your profile picture'): with gr.Tab('Upload your profile picture'):
upload_img_me = gr.File(type='binary') upload_img_me = gr.File(type='binary')
with gr.Tab('Upload TavernAI Character Card'):
upload_img_tavern = gr.File(type='binary')
input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider] input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
if args.cai_chat: if args.cai_chat:
@ -598,7 +617,8 @@ if args.chat or args.cai_chat:
save_btn.click(save_history, inputs=[], outputs=[download]) save_btn.click(save_history, inputs=[], outputs=[download])
character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1]) character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1])
upload.upload(upload_history, [upload, name1, name2], []) upload.upload(upload_history, [upload, name1, name2], [])
upload_btn.click(upload_character, [upload_char, upload_img, name1, name2], [character_menu]) upload_btn.click(upload_character, [upload_char, upload_img], [character_menu])
upload_img_tavern.upload(upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu])
upload_img_me.upload(upload_your_profile_picture, [upload_img_me], []) upload_img_me.upload(upload_your_profile_picture, [upload_img_me], [])
if args.cai_chat: if args.cai_chat: