mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 17:50:22 +01:00
Add support for TavernAI character cards (closes #31)
This commit is contained in:
parent
f71531186b
commit
3687962e6c
36
server.py
36
server.py
@ -6,6 +6,7 @@ import torch
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import io
|
import io
|
||||||
|
import base64
|
||||||
import sys
|
import sys
|
||||||
from sys import exit
|
from sys import exit
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -422,7 +423,9 @@ if args.chat or args.cai_chat:
|
|||||||
_history = []
|
_history = []
|
||||||
|
|
||||||
dialogue = re.sub('<START>', '', dialogue)
|
dialogue = re.sub('<START>', '', dialogue)
|
||||||
|
dialogue = re.sub('<start>', '', dialogue)
|
||||||
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
||||||
|
dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\1{name2}:', dialogue)
|
||||||
idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
|
idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
|
||||||
if len(idx) == 0:
|
if len(idx) == 0:
|
||||||
return _history
|
return _history
|
||||||
@ -442,6 +445,11 @@ if args.chat or args.cai_chat:
|
|||||||
_history.append(entry)
|
_history.append(entry)
|
||||||
entry = ['', '']
|
entry = ['', '']
|
||||||
|
|
||||||
|
print(f"\nDialogue tokenized to:\n\n", end='')
|
||||||
|
for i in _history:
|
||||||
|
print(i)
|
||||||
|
print("--------------------\n", end='')
|
||||||
|
|
||||||
return _history
|
return _history
|
||||||
|
|
||||||
def save_history():
|
def save_history():
|
||||||
@ -506,8 +514,8 @@ if args.chat or args.cai_chat:
|
|||||||
else:
|
else:
|
||||||
return name2, context, history['visible']
|
return name2, context, history['visible']
|
||||||
|
|
||||||
def upload_character(json_file, img, name1, name2):
|
def upload_character(json_file, img):
|
||||||
json_file = json_file.decode('utf-8')
|
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
||||||
data = json.loads(json_file)
|
data = json.loads(json_file)
|
||||||
outfile_name = data["char_name"]
|
outfile_name = data["char_name"]
|
||||||
i = 1
|
i = 1
|
||||||
@ -517,15 +525,24 @@ if args.chat or args.cai_chat:
|
|||||||
with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
|
with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
|
||||||
f.write(json_file)
|
f.write(json_file)
|
||||||
if img is not None:
|
if img is not None:
|
||||||
img = Image.open(io.BytesIO(img)).convert('RGB')
|
img = Image.open(io.BytesIO(img))
|
||||||
img.save(Path(f'characters/{outfile_name}.jpg'))
|
img.save(Path(f'characters/{outfile_name}.png'))
|
||||||
print(f'New character saved to "characters/{outfile_name}.json".')
|
print(f'New character saved to "characters/{outfile_name}.json".')
|
||||||
return outfile_name
|
return outfile_name
|
||||||
|
|
||||||
|
def upload_tavern_character(img, name1, name2):
|
||||||
|
_img = Image.open(io.BytesIO(img))
|
||||||
|
_img.getexif()
|
||||||
|
decoded_string = base64.b64decode(_img.info['chara'])
|
||||||
|
_json = json.loads(decoded_string)
|
||||||
|
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
|
||||||
|
_json['example_dialogue'] = _json['example_dialogue'].replace('{{user}}', name1).replace('{{char}}', _json['char_name'])
|
||||||
|
return upload_character(json.dumps(_json), img)
|
||||||
|
|
||||||
def upload_your_profile_picture(img):
|
def upload_your_profile_picture(img):
|
||||||
img = Image.open(io.BytesIO(img)).convert('RGB')
|
img = Image.open(io.BytesIO(img))
|
||||||
img.save(Path(f'img_me.jpg'))
|
img.save(Path(f'img_me.png'))
|
||||||
print(f'Profile picture saved to "img_me.jpg"')
|
print(f'Profile picture saved to "img_me.png"')
|
||||||
|
|
||||||
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
|
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
|
||||||
with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface:
|
with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface:
|
||||||
@ -579,6 +596,8 @@ if args.chat or args.cai_chat:
|
|||||||
upload_btn = gr.Button(value="Submit")
|
upload_btn = gr.Button(value="Submit")
|
||||||
with gr.Tab('Upload your profile picture'):
|
with gr.Tab('Upload your profile picture'):
|
||||||
upload_img_me = gr.File(type='binary')
|
upload_img_me = gr.File(type='binary')
|
||||||
|
with gr.Tab('Upload TavernAI Character Card'):
|
||||||
|
upload_img_tavern = gr.File(type='binary')
|
||||||
|
|
||||||
input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
|
input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
|
||||||
if args.cai_chat:
|
if args.cai_chat:
|
||||||
@ -598,7 +617,8 @@ if args.chat or args.cai_chat:
|
|||||||
save_btn.click(save_history, inputs=[], outputs=[download])
|
save_btn.click(save_history, inputs=[], outputs=[download])
|
||||||
character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1])
|
character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1])
|
||||||
upload.upload(upload_history, [upload, name1, name2], [])
|
upload.upload(upload_history, [upload, name1, name2], [])
|
||||||
upload_btn.click(upload_character, [upload_char, upload_img, name1, name2], [character_menu])
|
upload_btn.click(upload_character, [upload_char, upload_img], [character_menu])
|
||||||
|
upload_img_tavern.upload(upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu])
|
||||||
upload_img_me.upload(upload_your_profile_picture, [upload_img_me], [])
|
upload_img_me.upload(upload_your_profile_picture, [upload_img_me], [])
|
||||||
|
|
||||||
if args.cai_chat:
|
if args.cai_chat:
|
||||||
|
Loading…
Reference in New Issue
Block a user