mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
Add MiniCPM, Deepseek V2 chat template + clean up llama_chat_apply_template_internal
(#8172)
* tmp_contains * minicpm chat template * add DeepSeek Lite template * change deepseek-lite to deepseek2 * correct code comment * correct code from master branch
This commit is contained in:
parent
38373cfbab
commit
26a39bbd6b
@ -19613,7 +19613,10 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
std::string & dest, bool add_ass) {
|
std::string & dest, bool add_ass) {
|
||||||
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
|
auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
|
||||||
|
return tmpl.find(haystack) != std::string::npos;
|
||||||
|
};
|
||||||
|
if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
|
||||||
// chatml template
|
// chatml template
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
|
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
|
||||||
@ -19621,16 +19624,16 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|im_start|>assistant\n";
|
ss << "<|im_start|>assistant\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) {
|
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
|
||||||
// llama2 template and its variants
|
// llama2 template and its variants
|
||||||
// [variant] support system message
|
// [variant] support system message
|
||||||
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos || tmpl == "mistral";
|
bool support_system_message = tmpl_contains("<<SYS>>") || tmpl == "mistral";
|
||||||
// [variant] space before + after response
|
// [variant] space before + after response
|
||||||
bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
|
bool space_around_response = tmpl_contains("' ' + eos_token");
|
||||||
// [variant] add BOS inside history
|
// [variant] add BOS inside history
|
||||||
bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
|
bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
|
||||||
// [variant] trim spaces from the input message
|
// [variant] trim spaces from the input message
|
||||||
bool strip_message = tmpl.find("content.strip()") != std::string::npos;
|
bool strip_message = tmpl_contains("content.strip()");
|
||||||
// construct the prompt
|
// construct the prompt
|
||||||
bool is_inside_turn = true; // skip BOS at the beginning
|
bool is_inside_turn = true; // skip BOS at the beginning
|
||||||
ss << "[INST] ";
|
ss << "[INST] ";
|
||||||
@ -19656,7 +19659,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// llama2 templates seem to not care about "add_generation_prompt"
|
// llama2 templates seem to not care about "add_generation_prompt"
|
||||||
} else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos)) {
|
} else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
|
||||||
// Phi 3
|
// Phi 3
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
std::string role(message->role);
|
std::string role(message->role);
|
||||||
@ -19665,7 +19668,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|assistant|>\n";
|
ss << "<|assistant|>\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
|
} else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
|
||||||
// zephyr template
|
// zephyr template
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
|
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
|
||||||
@ -19673,7 +19676,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|assistant|>\n";
|
ss << "<|assistant|>\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) {
|
} else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
|
||||||
// mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
|
// mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
|
std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
|
||||||
@ -19682,7 +19685,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<s>assistant\n";
|
ss << "<s>assistant\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl.find("<start_of_turn>") != std::string::npos) {
|
} else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("<start_of_turn>")) {
|
||||||
// google/gemma-7b-it
|
// google/gemma-7b-it
|
||||||
std::string system_prompt = "";
|
std::string system_prompt = "";
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
@ -19704,7 +19707,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<start_of_turn>model\n";
|
ss << "<start_of_turn>model\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
|
} else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
|
||||||
// OrionStarAI/Orion-14B-Chat
|
// OrionStarAI/Orion-14B-Chat
|
||||||
std::string system_prompt = "";
|
std::string system_prompt = "";
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
@ -19724,7 +19727,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
ss << message->content << "</s>";
|
ss << message->content << "</s>";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (tmpl == "openchat" || tmpl.find("GPT4 Correct ") != std::string::npos) {
|
} else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
|
||||||
// openchat/openchat-3.5-0106,
|
// openchat/openchat-3.5-0106,
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
std::string role(message->role);
|
std::string role(message->role);
|
||||||
@ -19738,13 +19741,13 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "GPT4 Correct Assistant:";
|
ss << "GPT4 Correct Assistant:";
|
||||||
}
|
}
|
||||||
} else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl.find("USER: ") != std::string::npos && tmpl.find("ASSISTANT: ") != std::string::npos)) {
|
} else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
|
||||||
// eachadea/vicuna-13b-1.1 (and Orca variant)
|
// eachadea/vicuna-13b-1.1 (and Orca variant)
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
std::string role(message->role);
|
std::string role(message->role);
|
||||||
if (role == "system") {
|
if (role == "system") {
|
||||||
// Orca-Vicuna variant uses a system prefix
|
// Orca-Vicuna variant uses a system prefix
|
||||||
if (tmpl == "vicuna-orca" || tmpl.find("SYSTEM: ") != std::string::npos) {
|
if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
|
||||||
ss << "SYSTEM: " << message->content << "\n";
|
ss << "SYSTEM: " << message->content << "\n";
|
||||||
} else {
|
} else {
|
||||||
ss << message->content << "\n\n";
|
ss << message->content << "\n\n";
|
||||||
@ -19758,7 +19761,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "ASSISTANT:";
|
ss << "ASSISTANT:";
|
||||||
}
|
}
|
||||||
} else if (tmpl == "deepseek" || (tmpl.find("### Instruction:") != std::string::npos && tmpl.find("<|EOT|>") != std::string::npos)) {
|
} else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
|
||||||
// deepseek-ai/deepseek-coder-33b-instruct
|
// deepseek-ai/deepseek-coder-33b-instruct
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
std::string role(message->role);
|
std::string role(message->role);
|
||||||
@ -19773,7 +19776,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "### Response:\n";
|
ss << "### Response:\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) {
|
} else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
|
||||||
// CohereForAI/c4ai-command-r-plus
|
// CohereForAI/c4ai-command-r-plus
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
std::string role(message->role);
|
std::string role(message->role);
|
||||||
@ -19788,7 +19791,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
|
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
|
||||||
}
|
}
|
||||||
} else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) {
|
} else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
|
||||||
// Llama 3
|
// Llama 3
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
std::string role(message->role);
|
std::string role(message->role);
|
||||||
@ -19797,6 +19800,33 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
|
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
|
||||||
}
|
}
|
||||||
|
} else if (tmpl == "minicpm" || tmpl_contains(u8"<用户>")) {
|
||||||
|
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||||
|
for (auto message : chat) {
|
||||||
|
std::string role(message->role);
|
||||||
|
if (role == "user") {
|
||||||
|
ss << u8"<用户>";
|
||||||
|
ss << trim(message->content);
|
||||||
|
ss << "<AI>";
|
||||||
|
} else {
|
||||||
|
ss << trim(message->content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
|
||||||
|
// DeepSeek-V2
|
||||||
|
for (auto message : chat) {
|
||||||
|
std::string role(message->role);
|
||||||
|
if (role == "system") {
|
||||||
|
ss << message->content << "\n\n";
|
||||||
|
} else if (role == "user") {
|
||||||
|
ss << "User: " << message->content << "\n\n";
|
||||||
|
} else if (role == "assistant") {
|
||||||
|
ss << "Assistant: " << message->content << u8"<|end▁of▁sentence|>";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (add_ass) {
|
||||||
|
ss << "Assistant:";
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// template not supported
|
// template not supported
|
||||||
return -1;
|
return -1;
|
||||||
|
@ -57,7 +57,11 @@ int main(void) {
|
|||||||
//Phi-3-medium
|
//Phi-3-medium
|
||||||
"{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
|
"{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
|
||||||
//Phi-3-vision
|
//Phi-3-vision
|
||||||
"{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}"
|
"{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
|
||||||
|
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||||
|
u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
|
||||||
|
// DeepSeek-V2
|
||||||
|
"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
|
||||||
};
|
};
|
||||||
std::vector<std::string> expected_output = {
|
std::vector<std::string> expected_output = {
|
||||||
// teknium/OpenHermes-2.5-Mistral-7B
|
// teknium/OpenHermes-2.5-Mistral-7B
|
||||||
@ -94,6 +98,10 @@ int main(void) {
|
|||||||
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
|
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
|
||||||
//Phi-3-vision
|
//Phi-3-vision
|
||||||
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
|
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
|
||||||
|
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||||
|
u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
|
||||||
|
// DeepSeek-V2
|
||||||
|
u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
|
||||||
};
|
};
|
||||||
std::vector<char> formatted_chat(1024);
|
std::vector<char> formatted_chat(1024);
|
||||||
int32_t res;
|
int32_t res;
|
||||||
|
Loading…
Reference in New Issue
Block a user