mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 13:27:21 +01:00
llama : pre-tokenize non-special user-defined tokens first
This commit is contained in:
parent
ac0f33c920
commit
d5d30b20c3
@ -5495,28 +5495,6 @@ static void llm_load_vocab(
|
|||||||
vocab.token_to_id[word] = i;
|
vocab.token_to_id[word] = i;
|
||||||
vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
|
vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
|
||||||
|
|
||||||
// TODO: properly handle pre-normalized added_tokens and remove this
|
|
||||||
// handle space tokens with dual tokens,
|
|
||||||
// like the pre-normalized added_tokens
|
|
||||||
// of neox-style tokenizers (mpt, olmo, stablelm, etc)
|
|
||||||
if (word.find(' ') != std::string::npos) {
|
|
||||||
// same as in the internal `unicode_byte_encoding_process`
|
|
||||||
// TODO: extract and expose this in some unicode_* function
|
|
||||||
std::string text_utf;
|
|
||||||
auto utf_word = unicode_cpts_from_utf8(word);
|
|
||||||
for (size_t i = 0; i < utf_word.size(); ++i) {
|
|
||||||
text_utf += unicode_cpt_to_utf8(utf_word[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string encoded_token;
|
|
||||||
for (char & c : text_utf) {
|
|
||||||
encoded_token += unicode_byte_to_utf8(c);
|
|
||||||
}
|
|
||||||
|
|
||||||
// override token id
|
|
||||||
vocab.token_to_id[encoded_token] = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto & token_data = vocab.id_to_token[i];
|
auto & token_data = vocab.id_to_token[i];
|
||||||
token_data.text = std::move(word);
|
token_data.text = std::move(word);
|
||||||
token_data.score = scores ? scores[i] : 0.0f;
|
token_data.score = scores ? scores[i] : 0.0f;
|
||||||
@ -5534,6 +5512,13 @@ static void llm_load_vocab(
|
|||||||
default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break;
|
default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ((token_data.attr & LLAMA_TOKEN_ATTR_USER_DEFINED) && token_data.text.find('<') && token_data.text.rfind('>')) {
|
||||||
|
// Some models mark some added tokens which ought to be control tokens as not special.
|
||||||
|
// (e.g. command-r, command-r-plus, deepseek-coder)
|
||||||
|
// TODO: should this be fixed in the convert script instead?
|
||||||
|
token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
|
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
|
||||||
|
|
||||||
@ -15426,13 +15411,6 @@ struct llm_tokenizer_bpe {
|
|||||||
"[0-9][0-9][0-9]",
|
"[0-9][0-9][0-9]",
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_MPT:
|
|
||||||
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
|
||||||
regex_exprs = {
|
|
||||||
"[ ]{2,24}", // the spaces from the added_tokens are split separately
|
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
|
||||||
};
|
|
||||||
break;
|
|
||||||
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
|
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_REFACT:
|
case LLAMA_VOCAB_PRE_TYPE_REFACT:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
|
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
|
||||||
@ -15442,6 +15420,8 @@ struct llm_tokenizer_bpe {
|
|||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_GPT2:
|
case LLAMA_VOCAB_PRE_TYPE_GPT2:
|
||||||
|
case LLAMA_VOCAB_PRE_TYPE_MPT:
|
||||||
|
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_JAIS:
|
case LLAMA_VOCAB_PRE_TYPE_JAIS:
|
||||||
regex_exprs = {
|
regex_exprs = {
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
@ -15523,10 +15503,6 @@ struct llm_tokenizer_bpe {
|
|||||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||||
int final_prev_index = -1;
|
int final_prev_index = -1;
|
||||||
|
|
||||||
// FIXME: pre-tokenize added_tokens (user-defined tokens) before other pre-tokenization
|
|
||||||
// ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
|
|
||||||
// (useful for neox-style tokenizers)
|
|
||||||
|
|
||||||
const auto word_collection = unicode_regex_split(text, regex_exprs);
|
const auto word_collection = unicode_regex_split(text, regex_exprs);
|
||||||
|
|
||||||
symbols_final.clear();
|
symbols_final.clear();
|
||||||
@ -16192,12 +16168,20 @@ struct fragment_buffer_variant {
|
|||||||
|
|
||||||
// #define PRETOKENIZERDEBUG
|
// #define PRETOKENIZERDEBUG
|
||||||
|
|
||||||
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
|
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
|
||||||
// for each special token
|
// for each special token
|
||||||
for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
|
for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
|
||||||
const auto & data = vocab.id_to_token[special_id];
|
const auto & data = vocab.id_to_token[special_id];
|
||||||
const auto & special_token = data.text;
|
const auto & special_token = data.text;
|
||||||
|
|
||||||
|
if (!parse_special && (data.attr & LLAMA_TOKEN_ATTR_CONTROL)) {
|
||||||
|
// Only ignore control tokens when parse_special == false
|
||||||
|
continue;
|
||||||
|
// User-defined tokens are still pre-tokenized before everything else
|
||||||
|
// ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
|
||||||
|
// This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
|
||||||
|
}
|
||||||
|
|
||||||
// for each text fragment
|
// for each text fragment
|
||||||
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
|
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
|
||||||
while (it != buffer.end()) {
|
while (it != buffer.end()) {
|
||||||
@ -16310,7 +16294,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
|||||||
|
|
||||||
if (!raw_text.empty()) {
|
if (!raw_text.empty()) {
|
||||||
fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
|
fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
|
||||||
if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
|
tokenizer_st_partition(vocab, fragment_buffer, parse_special);
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (vocab.type) {
|
switch (vocab.type) {
|
||||||
|
@ -195,7 +195,7 @@ int main(int argc, char **argv) {
|
|||||||
const bool add_special = false;
|
const bool add_special = false;
|
||||||
|
|
||||||
for (const auto & test_kv : k_tests) {
|
for (const auto & test_kv : k_tests) {
|
||||||
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, true);
|
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special);
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf("src: '%s'\n", test_kv.first.c_str());
|
printf("src: '%s'\n", test_kv.first.c_str());
|
||||||
@ -253,7 +253,7 @@ int main(int argc, char **argv) {
|
|||||||
{
|
{
|
||||||
const auto t_start = ggml_time_us();
|
const auto t_start = ggml_time_us();
|
||||||
|
|
||||||
res = llama_tokenize(ctx, text, add_special, true);
|
res = llama_tokenize(ctx, text, add_special);
|
||||||
|
|
||||||
const auto t_end = ggml_time_us();
|
const auto t_end = ggml_time_us();
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user