// 2024-02-19 09:23:37 +01:00
# include <iostream>
# include <string>
# include <vector>
# include <sstream>
# undef NDEBUG
# include <cassert>
# include "llama.h"
// Test driver for llama_chat_apply_template().
//
// Formats one fixed six-message conversation with each entry of `templates`
// and asserts that the result equals the entry at the same index in
// `expected_output`. Before that, it checks that an unrecognised template
// string is rejected with a negative return value.
//
// Returns 0 when every check passes; a failed assert aborts the process
// (NDEBUG is undefined at the top of the file so asserts are always active).
int main(void) {
    llama_chat_message conversation[] = {
        { " system " , " You are a helpful assistant " },
        { " user " , " Hello " },
        { " assistant " , " Hi there " },
        { " user " , " Who are you " },
        { " assistant " , " I am an assistant " },
        { " user " , " Another question " },
    };
    // Derive the count from the array so it cannot drift when messages are added or removed.
    const size_t message_count = sizeof(conversation) / sizeof(conversation[0]);

    // NOTE(review): every string literal in this function is reproduced exactly as found.
    // The spacing inside them looks like it may have been altered by a reformatting tool;
    // confirm against the intended fixtures before trusting a comparison failure.
    std::vector<std::string> templates = {
        // teknium/OpenHermes-2.5-Mistral-7B
        " {% for message in messages %}{{'<|im_start|>' + message['role'] + ' \\ n' + message['content'] + '<|im_end|>' + ' \\ n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant \\ n' }}{% endif %} ",
        // mistralai/Mistral-7B-Instruct-v0.2
        " {{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %} ",
        // TheBloke/FusionNet_34Bx2_MoE-AWQ
        " {%- for idx in range(0, messages|length) -%} \\ n{%- if messages[idx]['role'] == 'user' -%} \\ n{%- if idx > 1 -%} \\ n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}} \\ n{%- else -%} \\ n{{- messages[idx]['content'] + ' [/INST]' -}} \\ n{%- endif -%} \\ n{% elif messages[idx]['role'] == 'system' %} \\ n{{- '[INST] <<SYS>> \\ \\ n' + messages[idx]['content'] + ' \\ \\ n<</SYS>> \\ \\ n \\ \\ n' -}} \\ n{%- elif messages[idx]['role'] == 'assistant' -%} \\ n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}} \\ n{% endif %} \\ n{% endfor %} ",
        // bofenghuang/vigogne-2-70b-chat
        " {{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>> \\ \\ n' + system_message + ' \\ \\ n<</SYS>> \\ \\ n \\ \\ n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>> \\ \\ n' + content.strip() + ' \\ \\ n<</SYS>> \\ \\ n \\ \\ n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %} ",
        // mlabonne/AlphaMonarch-7B
        " {% for message in messages %}{{bos_token + message['role'] + ' \\ n' + message['content'] + eos_token + ' \\ n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant \\ n' }}{% endif %} ",
    };

    // Expected rendering of `conversation` for the template at the same index above.
    std::vector<std::string> expected_output = {
        // teknium/OpenHermes-2.5-Mistral-7B
        " <|im_start|>system \n You are a helpful assistant<|im_end|> \n <|im_start|>user \n Hello<|im_end|> \n <|im_start|>assistant \n Hi there<|im_end|> \n <|im_start|>user \n Who are you<|im_end|> \n <|im_start|>assistant \n I am an assistant <|im_end|> \n <|im_start|>user \n Another question<|im_end|> \n <|im_start|>assistant \n ",
        // mistralai/Mistral-7B-Instruct-v0.2
        " [INST] You are a helpful assistant \n Hello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST] ",
        // TheBloke/FusionNet_34Bx2_MoE-AWQ
        " [INST] <<SYS>> \n You are a helpful assistant \n <</SYS>> \n \n Hello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST] ",
        // bofenghuang/vigogne-2-70b-chat
        " [INST] <<SYS>> \n You are a helpful assistant \n <</SYS>> \n \n Hello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST] ",
        // mlabonne/AlphaMonarch-7B
        " system \n You are a helpful assistant</s> \n <s>user \n Hello</s> \n <s>assistant \n Hi there</s> \n <s>user \n Who are you</s> \n <s>assistant \n I am an assistant </s> \n <s>user \n Another question</s> \n <s>assistant \n ",
    };

    // The two tables are indexed in lockstep below; a length mismatch would read out of range.
    assert(templates.size() == expected_output.size());

    const size_t initial_buffer_size = 1024;
    std::vector<char> formatted_chat(initial_buffer_size);
    int32_t res;

    // test invalid chat template
    res = llama_chat_apply_template(nullptr, " INVALID TEMPLATE ", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
    assert(res < 0);

    for (size_t i = 0; i < templates.size(); ++i) {
        const std::string & custom_template = templates[i];
        const std::string & expected        = expected_output[i];

        // Start every case from the same buffer size, as earlier iterations shrink or grow it.
        formatted_chat.resize(initial_buffer_size);
        res = llama_chat_apply_template(
            nullptr,
            custom_template.c_str(),
            conversation,
            message_count,
            true,
            formatted_chat.data(),
            formatted_chat.size()
        );

        // A negative result means the template was rejected. Fail here with a clear
        // assertion instead of letting the value wrap to a huge size_t in resize().
        assert(res >= 0);

        // NOTE(review): relies on the llama.h contract that the return value is the full
        // formatted length even when it exceeds the buffer - confirm against the header.
        // In that case the buffer holds a truncated result, so grow it and format again.
        if (static_cast<size_t>(res) > formatted_chat.size()) {
            formatted_chat.resize(static_cast<size_t>(res));
            res = llama_chat_apply_template(
                nullptr,
                custom_template.c_str(),
                conversation,
                message_count,
                true,
                formatted_chat.data(),
                formatted_chat.size()
            );
            assert(res >= 0);
        }

        // Trim the buffer to the bytes actually produced before building the string.
        formatted_chat.resize(static_cast<size_t>(res));
        const std::string output(formatted_chat.data(), formatted_chat.size());
        std::cout << output << " \n ------------------------- \n ";
        assert(output == expected);
    }
    return 0;
}