feat: save chat template with instruction template

This commit is contained in:
A0nameless0man 2024-04-21 16:10:59 +00:00
parent 0877741b03
commit 55306aa4e1
2 changed files with 11 additions and 11 deletions

View File

@ -704,22 +704,22 @@ def load_character(character, name1, name2):
return name1, name2, picture, greeting, context return name1, name2, picture, greeting, context
def load_instruction_template(template): def load_instruction_template(template, current_instruction_template=None, current_chat_template=None):
if template == 'None': if template == 'None':
return '' return '', current_chat_template
for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]: for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]:
if filepath.exists(): if filepath.exists():
break break
else: else:
return '' return '', current_chat_template
file_contents = open(filepath, 'r', encoding='utf-8').read() file_contents = open(filepath, 'r', encoding='utf-8').read()
data = yaml.safe_load(file_contents) data = yaml.safe_load(file_contents)
if 'instruction_template' in data: if 'instruction_template' in data:
return data['instruction_template'] return data['instruction_template'], data['chat_template'] if 'chat_template' in data else current_chat_template
else: else:
return jinja_template_from_old_format(data) return jinja_template_from_old_format(data), current_chat_template
@functools.cache @functools.cache
@ -821,9 +821,10 @@ def generate_character_yaml(name, greeting, context):
return yaml.dump(data, sort_keys=False, width=float("inf")) return yaml.dump(data, sort_keys=False, width=float("inf"))
def generate_instruction_template_yaml(instruction_template): def generate_instruction_template_yaml(instruction_template, chat_template):
data = { data = {
'instruction_template': instruction_template 'instruction_template': instruction_template,
'chat_template': chat_template
} }
return my_yaml_output(data) return my_yaml_output(data)

View File

@ -316,13 +316,12 @@ def create_event_handlers():
shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter')) shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter'))
shared.gradio['load_template'].click( shared.gradio['load_template'].click(
chat.load_instruction_template, gradio('instruction_template'), gradio('instruction_template_str')).then( chat.load_instruction_template, gradio(['instruction_template', 'instruction_template_str', 'chat_template_str']), gradio(['instruction_template_str', 'chat_template_str']))
lambda: "Select template to load...", None, gradio('instruction_template'))
shared.gradio['save_template'].click( shared.gradio['save_template'].click(
lambda: 'My Template.yaml', None, gradio('save_filename')).then( lambda x: x + '.yaml', gradio('instruction_template'), gradio('save_filename')).then(
lambda: 'instruction-templates/', None, gradio('save_root')).then( lambda: 'instruction-templates/', None, gradio('save_root')).then(
chat.generate_instruction_template_yaml, gradio('instruction_template_str'), gradio('save_contents')).then( chat.generate_instruction_template_yaml, gradio(['instruction_template_str', 'chat_template_str']), gradio('save_contents')).then(
lambda: gr.update(visible=True), None, gradio('file_saver')) lambda: gr.update(visible=True), None, gradio('file_saver'))
shared.gradio['delete_template'].click( shared.gradio['delete_template'].click(