tokenizer : BPE fixes (#7530)

* Random test: add_bos_token, add_eos_token
* Random test: add BPE models for testing
* Custom regex split fails with codepoint 0
* Fix falcon punctuation regex
* Refactor llm_tokenizer_bpe: move code to constructor
* Move 'add_special_bos/eos' logic to llm_tokenizer_bpe
* Move tokenizer flags to vocab structure.
* Default values for special_add_bos/eos
* Build vocab.special_tokens_cache using vocab token types
* Generalize 'jina-v2' per token attributes
* Fix unicode whitespaces (deepseek-coder, deepseek-llm)
* Skip missing byte tokens (falcon)
* Better unicode data generation
* Replace char32_t with uint32_t
This commit is contained in:
jaime-m-p 2024-06-18 18:40:52 +02:00 committed by GitHub
parent 91c188d6c2
commit 37bef89433
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1283 additions and 1053 deletions

179
llama.cpp
View File

@ -2310,16 +2310,17 @@ struct llama_vocab {
id special_cls_id = -1; id special_cls_id = -1;
id special_mask_id = -1; id special_mask_id = -1;
int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
id linefeed_id = 13; id linefeed_id = 13;
id special_prefix_id = -1; id special_prefix_id = -1;
id special_suffix_id = -1; id special_suffix_id = -1;
id special_middle_id = -1; id special_middle_id = -1;
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
bool add_space_prefix = true; // tokenizer flags
bool tokenizer_add_space_prefix = true;
bool tokenizer_add_bos = false;
bool tokenizer_add_eos = false;
bool tokenizer_ignore_merges = false;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const { int find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
GGML_ASSERT(token_left.find(' ') == std::string::npos); GGML_ASSERT(token_left.find(' ') == std::string::npos);
@ -4770,7 +4771,7 @@ static void llm_load_vocab(
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) { if (add_space_prefix_keyidx != -1) {
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
} // The default value of add_space_prefix is true. } // The default value of add_space_prefix is true.
} else if (tokenizer_model == "bert") { } else if (tokenizer_model == "bert") {
vocab.type = LLAMA_VOCAB_TYPE_WPM; vocab.type = LLAMA_VOCAB_TYPE_WPM;
@ -4783,13 +4784,13 @@ static void llm_load_vocab(
vocab.special_pad_id = 0; vocab.special_pad_id = 0;
vocab.special_cls_id = 101; vocab.special_cls_id = 101;
vocab.special_mask_id = 103; vocab.special_mask_id = 103;
vocab.add_space_prefix = false; vocab.tokenizer_add_space_prefix = false;
} else if (tokenizer_model == "gpt2") { } else if (tokenizer_model == "gpt2") {
vocab.type = LLAMA_VOCAB_TYPE_BPE; vocab.type = LLAMA_VOCAB_TYPE_BPE;
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) { if (add_space_prefix_keyidx != -1) {
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
} }
// read bpe merges and populate bpe ranks // read bpe merges and populate bpe ranks
@ -4847,6 +4848,8 @@ static void llm_load_vocab(
tokenizer_pre == "llama-v3" || tokenizer_pre == "llama-v3" ||
tokenizer_pre == "llama-bpe") { tokenizer_pre == "llama-bpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
vocab.tokenizer_ignore_merges = true;
vocab.tokenizer_add_bos = true;
} else if ( } else if (
tokenizer_pre == "deepseek-llm") { tokenizer_pre == "deepseek-llm") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
@ -4897,6 +4900,14 @@ static void llm_load_vocab(
} else { } else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
} }
} else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
vocab.tokenizer_add_bos = true;
vocab.tokenizer_add_eos = false;
} else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
vocab.tokenizer_add_bos = true;
vocab.tokenizer_add_eos = false;
} else { } else {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} }
@ -5041,10 +5052,10 @@ static void llm_load_vocab(
bool temp = true; bool temp = true;
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) { if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
vocab.special_add_bos = int(temp); vocab.tokenizer_add_bos = temp;
} }
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) { if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
vocab.special_add_eos = int(temp); vocab.tokenizer_add_eos = temp;
} }
} }
@ -5144,7 +5155,7 @@ static void llm_load_vocab(
); );
// set attributes by model/tokenizer name // set attributes by model/tokenizer name
if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) { if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true); _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
} else if (_contains_any(model_name, {"phi-3", "phi3"})) { } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
for (auto id : vocab.cache_special_tokens) { for (auto id : vocab.cache_special_tokens) {
@ -13158,113 +13169,143 @@ struct llm_bigram_bpe {
}; };
struct llm_tokenizer_bpe { struct llm_tokenizer_bpe {
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {} llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
int final_prev_index = -1;
bool ignore_merges = false;
std::vector<std::string> word_collection;
switch (vocab.type) {
case LLAMA_VOCAB_TYPE_BPE:
switch (vocab.type_pre) { switch (vocab.type_pre) {
case LLAMA_VOCAB_PRE_TYPE_LLAMA3: case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
ignore_merges = true; regex_exprs = {
word_collection = unicode_regex_split(text, {
// original regex from tokenizer.json // original regex from tokenizer.json
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
// adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_DBRX: case LLAMA_VOCAB_PRE_TYPE_DBRX:
case LLAMA_VOCAB_PRE_TYPE_SMAUG: case LLAMA_VOCAB_PRE_TYPE_SMAUG:
word_collection = unicode_regex_split(text, { regex_exprs = {
// same as llama3 // same as llama3
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM: case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
word_collection = unicode_regex_split(text, { regex_exprs = {
"[\r\n]", "[\r\n]",
"\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
"\\s?[!-/:-~---‟ -。]+", "\\s?[!-/:-~---‟ -。]+",
"\\s+$", "\\s+$",
"[一-龥ࠀ-一가-퟿]+", "[一-龥ࠀ-一가-퟿]+",
"\\p{N}+", "\\p{N}+",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
word_collection = unicode_regex_split(text, { regex_exprs = {
"[\r\n]", "[\r\n]",
"\\s?\\p{L}+", "\\s?\\p{L}+",
"\\s?\\p{P}+", "\\s?\\p{P}+",
"[一-龥ࠀ-一가-퟿]+", "[一-龥ࠀ-一가-퟿]+",
"\\p{N}", "\\p{N}",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_FALCON: case LLAMA_VOCAB_PRE_TYPE_FALCON:
word_collection = unicode_regex_split(text, { regex_exprs = {
"[\\p{P}\\$\\+<=>\\^~\\|]+", "[\\p{P}\\$\\+<=>\\^~\\|`]+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
"[0-9][0-9][0-9]", "[0-9][0-9][0-9]",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_MPT: case LLAMA_VOCAB_PRE_TYPE_MPT:
// TODO: MPT pre-tokenization regexes are unknown // TODO: MPT pre-tokenization regexes are unknown
// the following are close, but not exact. run the following: // the following are close, but not exact. run the following:
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf // ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed"); GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
word_collection = unicode_regex_split(text, { regex_exprs = {
"\\s?\\p{L}+", "\\s?\\p{L}+",
"\\s?\\p{P}+", "\\s?\\p{P}+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_STARCODER: case LLAMA_VOCAB_PRE_TYPE_STARCODER:
case LLAMA_VOCAB_PRE_TYPE_REFACT: case LLAMA_VOCAB_PRE_TYPE_REFACT:
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
word_collection = unicode_regex_split(text, { regex_exprs = {
"\\p{N}", "\\p{N}",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_GPT2: case LLAMA_VOCAB_PRE_TYPE_GPT2:
case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_OLMO:
word_collection = unicode_regex_split(text, { regex_exprs = {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
case LLAMA_VOCAB_PRE_TYPE_QWEN2: case LLAMA_VOCAB_PRE_TYPE_QWEN2:
word_collection = unicode_regex_split(text, { regex_exprs = {
// original regex from tokenizer.json // original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_PORO: case LLAMA_VOCAB_PRE_TYPE_PORO:
word_collection = unicode_regex_split(text, { regex_exprs = {
" ?[^(\\s|.,!?…。,、।۔،)]+", " ?[^(\\s|.,!?…。,、।۔،)]+",
}); };
break; break;
default: default:
// default regex for BPE tokenization pre-processing // default regex for BPE tokenization pre-processing
word_collection = unicode_regex_split(text, { regex_exprs = {
"[\\p{P}\\$\\+<=>\\^~\\|]+", "[\\p{P}\\$\\+<=>\\^~\\|]+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
"\\p{N}+", "\\p{N}+",
"[0-9][0-9][0-9]", "[0-9][0-9][0-9]",
}); };
break; break;
} }
break;
default:
GGML_ASSERT(false);
break;
} }
void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
output.push_back(token_id);
}
bool append_bos(std::vector<llama_vocab::id> & output) const {
if (vocab.tokenizer_add_bos) {
GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id);
return true;
}
return false;
}
bool append_eos(std::vector<llama_vocab::id> & output) const {
if (vocab.tokenizer_add_eos) {
GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id);
return true;
}
return false;
}
void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
LLAMA_LOG_WARN(
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
LLAMA_LOG_WARN(
"%s: Added a EOS token to the prompt as specified by the model but the prompt "
"also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
int final_prev_index = -1;
const auto word_collection = unicode_regex_split(text, regex_exprs);
symbols_final.clear(); symbols_final.clear();
for (auto & word : word_collection) { for (auto & word : word_collection) {
@ -13274,7 +13315,7 @@ struct llm_tokenizer_bpe {
int index = 0; int index = 0;
size_t offset = 0; size_t offset = 0;
if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
offset = word.size(); offset = word.size();
} }
@ -13355,10 +13396,9 @@ struct llm_tokenizer_bpe {
for (auto j = str.begin(); j != str.end(); ++j) { for (auto j = str.begin(); j != str.end(); ++j) {
std::string byte_str(1, *j); std::string byte_str(1, *j);
auto token_multibyte = vocab.token_to_id.find(byte_str); auto token_multibyte = vocab.token_to_id.find(byte_str);
if (token_multibyte == vocab.token_to_id.end()) { if (token_multibyte != vocab.token_to_id.end()) {
throw std::runtime_error("ERROR: byte not found in vocab"); output.push_back(token_multibyte->second);
} }
output.push_back((*token_multibyte).second);
} }
} else { } else {
output.push_back((*token).second); output.push_back((*token).second);
@ -13397,6 +13437,8 @@ private:
const llama_vocab & vocab; const llama_vocab & vocab;
std::vector<std::string> regex_exprs;
std::vector<llm_symbol> symbols; std::vector<llm_symbol> symbols;
std::vector<llm_symbol> symbols_final; std::vector<llm_symbol> symbols_final;
@ -13677,7 +13719,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
bool is_prev_special = false; bool is_prev_special = false;
if (add_special && vocab.special_add_bos != 0) { if (add_special && vocab.tokenizer_add_bos) {
GGML_ASSERT(vocab.special_bos_id != -1); GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id); output.push_back(vocab.special_bos_id);
is_prev_special = true; is_prev_special = true;
@ -13687,7 +13729,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
if (vocab.add_space_prefix) { if (vocab.tokenizer_add_space_prefix) {
if (!output.size() || is_prev_special) { // prefix with space if first token if (!output.size() || is_prev_special) { // prefix with space if first token
raw_text = " " + raw_text; raw_text = " " + raw_text;
} }
@ -13705,23 +13747,24 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
} }
} }
if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) { if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
LLAMA_LOG_WARN( LLAMA_LOG_WARN(
"%s: Added a BOS token to the prompt as specified by the model but the prompt " "%s: Added a BOS token to the prompt as specified by the model but the prompt "
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__); "Are you sure this is what you want?\n", __FUNCTION__);
} }
if (add_special && vocab.special_add_eos == 1) { if (add_special && vocab.tokenizer_add_eos) {
GGML_ASSERT(vocab.special_eos_id != -1); GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id); output.push_back(vocab.special_eos_id);
} }
} break; } break;
case LLAMA_VOCAB_TYPE_BPE: case LLAMA_VOCAB_TYPE_BPE:
{ {
if (add_special && vocab.special_add_bos != 0) { llm_tokenizer_bpe tokenizer(vocab);
GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id); if (add_special) {
tokenizer.append_bos(output);
} }
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
@ -13731,23 +13774,15 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif #endif
llm_tokenizer_bpe tokenizer(vocab);
tokenizer.tokenize(raw_text, output); tokenizer.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token); tokenizer.append(fragment.token, output);
} }
} }
if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) { if (add_special) {
LLAMA_LOG_WARN( tokenizer.append_eos(output);
"%s: Added a BOS token to the prompt as specified by the model but the prompt " tokenizer.check_double_bos_eos(output);
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
if (add_special && vocab.special_add_eos == 1) {
GGML_ASSERT(vocab.special_add_eos != -1);
output.push_back(vocab.special_eos_id);
} }
} break; } break;
case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_WPM:
@ -18320,11 +18355,11 @@ llama_token llama_token_nl(const struct llama_model * model) {
} }
int32_t llama_add_bos_token(const struct llama_model * model) { int32_t llama_add_bos_token(const struct llama_model * model) {
return model->vocab.special_add_bos; return model->vocab.tokenizer_add_bos;
} }
int32_t llama_add_eos_token(const struct llama_model * model) { int32_t llama_add_eos_token(const struct llama_model * model) {
return model->vocab.special_add_eos; return model->vocab.tokenizer_add_eos;
} }
llama_token llama_token_prefix(const struct llama_model * model) { llama_token llama_token_prefix(const struct llama_model * model) {

View File

@ -1,83 +1,143 @@
import regex import array
import ctypes
import unicodedata import unicodedata
import requests
class CoodepointFlags (ctypes.Structure):
_fields_ = [ # see definition in unicode.h
("is_undefined", ctypes.c_uint16, 1),
("is_number", ctypes.c_uint16, 1), # regex: \p{N}
("is_letter", ctypes.c_uint16, 1), # regex: \p{L}
("is_separator", ctypes.c_uint16, 1), # regex: \p{Z}
("is_accent_mark", ctypes.c_uint16, 1), # regex: \p{M}
("is_punctuation", ctypes.c_uint16, 1), # regex: \p{P}
("is_symbol", ctypes.c_uint16, 1), # regex: \p{S}
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
]
assert (ctypes.sizeof(CoodepointFlags) == 2)
MAX_CODEPOINTS = 0x110000 MAX_CODEPOINTS = 0x110000
regex_number = regex.compile(r'\p{N}') UNICODE_DATA_URL = "https://www.unicode.org/Public/UCD/latest/ucd/UnicodeData.txt"
regex_letter = regex.compile(r'\p{L}')
regex_separator = regex.compile(r'\p{Z}')
regex_accent_mark = regex.compile(r'\p{M}')
regex_punctuation = regex.compile(r'\p{P}')
regex_symbol = regex.compile(r'\p{S}')
regex_control = regex.compile(r'\p{C}')
regex_whitespace = regex.compile(r'\s')
codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
# see https://www.unicode.org/L2/L1999/UnicodeData.html
def unicode_data_iter():
res = requests.get(UNICODE_DATA_URL)
res.raise_for_status()
data = res.content.decode()
prev = []
for line in data.splitlines():
# ej: 0000;<control>;Cc;0;BN;;;;;N;NULL;;;;
line = line.split(";")
cpt = int(line[0], base=16)
assert cpt < MAX_CODEPOINTS
cpt_lower = int(line[-2] or "0", base=16)
assert cpt_lower < MAX_CODEPOINTS
cpt_upper = int(line[-3] or "0", base=16)
assert cpt_upper < MAX_CODEPOINTS
categ = line[2].strip()
assert len(categ) == 2
bidir = line[4].strip()
assert len(categ) == 2
name = line[1]
if name.endswith(", First>"):
prev = (cpt, cpt_lower, cpt_upper, categ, bidir)
continue
if name.endswith(", Last>"):
assert prev[1:] == (0, 0, categ, bidir)
for c in range(prev[0], cpt):
yield (c, cpt_lower, cpt_upper, categ, bidir)
yield (cpt, cpt_lower, cpt_upper, categ, bidir)
# see definition in unicode.h
CODEPOINT_FLAG_UNDEFINED = 0x0001 #
CODEPOINT_FLAG_NUMBER = 0x0002 # \p{N}
CODEPOINT_FLAG_LETTER = 0x0004 # \p{L}
CODEPOINT_FLAG_SEPARATOR = 0x0008 # \p{Z}
CODEPOINT_FLAG_MARK = 0x0010 # \p{M}
CODEPOINT_FLAG_PUNCTUATION = 0x0020 # \p{P}
CODEPOINT_FLAG_SYMBOL = 0x0040 # \p{S}
CODEPOINT_FLAG_CONTROL = 0x0080 # \p{C}
UNICODE_CATEGORY_TO_FLAG = {
"Cn": CODEPOINT_FLAG_UNDEFINED, # Undefined
"Cc": CODEPOINT_FLAG_CONTROL, # Control
"Cf": CODEPOINT_FLAG_CONTROL, # Format
"Co": CODEPOINT_FLAG_CONTROL, # Private Use
"Cs": CODEPOINT_FLAG_CONTROL, # Surrrogate
"Ll": CODEPOINT_FLAG_LETTER, # Lowercase Letter
"Lm": CODEPOINT_FLAG_LETTER, # Modifier Letter
"Lo": CODEPOINT_FLAG_LETTER, # Other Letter
"Lt": CODEPOINT_FLAG_LETTER, # Titlecase Letter
"Lu": CODEPOINT_FLAG_LETTER, # Uppercase Letter
"L&": CODEPOINT_FLAG_LETTER, # Cased Letter
"Mc": CODEPOINT_FLAG_MARK, # Spacing Mark
"Me": CODEPOINT_FLAG_MARK, # Enclosing Mark
"Mn": CODEPOINT_FLAG_MARK, # Nonspacing Mark
"Nd": CODEPOINT_FLAG_NUMBER, # Decimal Number
"Nl": CODEPOINT_FLAG_NUMBER, # Letter Number
"No": CODEPOINT_FLAG_NUMBER, # Other Number
"Pc": CODEPOINT_FLAG_PUNCTUATION, # Connector Punctuation
"Pd": CODEPOINT_FLAG_PUNCTUATION, # Dash Punctuation
"Pe": CODEPOINT_FLAG_PUNCTUATION, # Close Punctuation
"Pf": CODEPOINT_FLAG_PUNCTUATION, # Final Punctuation
"Pi": CODEPOINT_FLAG_PUNCTUATION, # Initial Punctuation
"Po": CODEPOINT_FLAG_PUNCTUATION, # Other Punctuation
"Ps": CODEPOINT_FLAG_PUNCTUATION, # Open Punctuation
"Sc": CODEPOINT_FLAG_SYMBOL, # Currency Symbol
"Sk": CODEPOINT_FLAG_SYMBOL, # Modifier Symbol
"Sm": CODEPOINT_FLAG_SYMBOL, # Math Symbol
"So": CODEPOINT_FLAG_SYMBOL, # Other Symbol
"Zl": CODEPOINT_FLAG_SEPARATOR, # Line Separator
"Zp": CODEPOINT_FLAG_SEPARATOR, # Paragraph Separator
"Zs": CODEPOINT_FLAG_SEPARATOR, # Space Separator
}
codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS
table_whitespace = [] table_whitespace = []
table_lowercase = [] table_lowercase = []
table_uppercase = [] table_uppercase = []
table_nfd = [] table_nfd = []
for codepoint in range(MAX_CODEPOINTS): for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
# convert codepoint to unicode character # convert codepoint to unicode character
char = chr(codepoint) char = chr(cpt)
# regex categories # codepoint category flags
flags = codepoint_flags[codepoint] codepoint_flags[cpt] = UNICODE_CATEGORY_TO_FLAG[categ]
flags.is_number = bool(regex_number.match(char))
flags.is_letter = bool(regex_letter.match(char))
flags.is_separator = bool(regex_separator.match(char))
flags.is_accent_mark = bool(regex_accent_mark.match(char))
flags.is_punctuation = bool(regex_punctuation.match(char))
flags.is_symbol = bool(regex_symbol.match(char))
flags.is_control = bool(regex_control.match(char))
flags.is_undefined = bytes(flags)[0] == 0
assert (not flags.is_undefined)
# whitespaces
if bool(regex_whitespace.match(char)):
table_whitespace.append(codepoint)
# lowercase conversion # lowercase conversion
lower = ord(char.lower()[0]) if cpt_lower:
if codepoint != lower: table_lowercase.append((cpt, cpt_lower))
table_lowercase.append((codepoint, lower))
# uppercase conversion # uppercase conversion
upper = ord(char.upper()[0]) if cpt_upper:
if codepoint != upper: table_uppercase.append((cpt, cpt_upper))
table_uppercase.append((codepoint, upper))
# NFD normalization # NFD normalization
norm = ord(unicodedata.normalize('NFD', char)[0]) norm = ord(unicodedata.normalize('NFD', char)[0])
if codepoint != norm: if cpt != norm:
table_nfd.append((codepoint, norm)) table_nfd.append((cpt, norm))
# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
table_whitespace.extend(range(0x0009, 0x000D + 1))
table_whitespace.extend(range(0x2000, 0x200A + 1))
table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000])
# sort by codepoint
table_whitespace.sort()
table_lowercase.sort()
table_uppercase.sort()
table_nfd.sort()
# group ranges with same flags # group ranges with same flags
ranges_flags = [(0, codepoint_flags[0])] # start, flags ranges_flags = [(0, codepoint_flags[0])] # start, flags
for codepoint, flags in enumerate(codepoint_flags): for codepoint, flags in enumerate(codepoint_flags):
if bytes(flags) != bytes(ranges_flags[-1][1]): if flags != ranges_flags[-1][1]:
ranges_flags.append((codepoint, flags)) ranges_flags.append((codepoint, flags))
ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags())) ranges_flags.append((MAX_CODEPOINTS, 0x0000))
# group ranges with same nfd # group ranges with same nfd
@ -90,8 +150,8 @@ for codepoint, norm in table_nfd:
ranges_nfd[-1] = (start, codepoint, norm) ranges_nfd[-1] = (start, codepoint, norm)
# Generate 'unicode-data.cpp' # Generate 'unicode-data.cpp':
# python ./scripts//gen-unicode-data.py > unicode-data.cpp
def out(line=""): def out(line=""):
print(line, end='\n') # noqa print(line, end='\n') # noqa
@ -110,12 +170,12 @@ out("""\
out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1") out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
for codepoint, flags in ranges_flags: for codepoint, flags in ranges_flags:
flags = int.from_bytes(bytes(flags), "little")
out("{0x%06X, 0x%04X}," % (codepoint, flags)) out("{0x%06X, 0x%04X}," % (codepoint, flags))
out("};\n") out("};\n")
out("const std::unordered_set<uint32_t> unicode_set_whitespace = {") out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
out(", ".join("0x%06X" % cpt for cpt in table_whitespace)) for codepoint in table_whitespace:
out("0x%06X," % codepoint)
out("};\n") out("};\n")
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {") out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")

View File

@ -11,13 +11,15 @@ import logging
import argparse import argparse
import subprocess import subprocess
import random import random
import unicodedata
from typing import Callable, Iterator from typing import Callable, Iterator
import cffi import cffi
from transformers import AutoTokenizer from transformers import AutoTokenizer
logger = logging.getLogger("test-tokenizer-random-bpe")
logger = logging.getLogger("test-tokenizer-random")
class LibLlama: class LibLlama:
@ -155,9 +157,14 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
'Cửa Việt', # llama-3, ignore_merges = true 'Cửa Việt', # llama-3, ignore_merges = true
'<s>a', # Phi-3 fail '<s>a', # Phi-3 fail
'<unk><|endoftext|><s>', # Phi-3 fail '<unk><|endoftext|><s>', # Phi-3 fail
'a\na', # TODO: Bert fail 'a\na', # bert fail
'"`', # falcon
' \u2e4e', # falcon
'a\xa0\xa0\x00b', # jina-v2-es
'one <mask>', # jina-v2-es <mask> lstrip=true
'a </s> b', # rstrip phi-3 'a </s> b', # rstrip phi-3
'a <mask> b', # lstrip jina-v2 'a <mask> b', # lstrip jina-v2
'\xa0aC', # deepseek
] ]
@ -189,17 +196,23 @@ def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
for m in range(iterations): for m in range(iterations):
rand.seed(m) rand.seed(m)
words = rand.choices(all_tokens, k=500) words = rand.choices(all_tokens, k=500)
if words[0] == tokenizer.bos_token: # skip spam warning of double BOS if words and words[0] == tokenizer.bos_token: # skip spam warning of double BOS
while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS
words.pop(0) words.pop(0)
if tokenizer.add_bos_token: # drop all starting BOS if tokenizer.add_bos_token: # drop all starting BOS
words.pop(0) words.pop(0)
if words and words[-1] == tokenizer.eos_token: # skip spam warning of double EOS
while len(words) > 1 and words[-2] == tokenizer.eos_token: # leave one trailing EOS
words.pop(-1)
if tokenizer.add_bos_token: # drop all trailing EOS
words.pop(-1)
yield "".join(words) yield "".join(words)
def generator_random_chars(iterations=100) -> Iterator[str]: def generator_random_chars(iterations=100) -> Iterator[str]:
"""Brute force random text with simple characters""" """Brute force random text with simple characters"""
NUM_WORDS = 400
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5) WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
CHARS = list(sorted(set(""" CHARS = list(sorted(set("""
ABCDEFGHIJKLMNOPQRSTUVWXYZ ABCDEFGHIJKLMNOPQRSTUVWXYZ
@ -213,12 +226,50 @@ def generator_random_chars(iterations=100) -> Iterator[str]:
for m in range(iterations): for m in range(iterations):
rand.seed(m) rand.seed(m)
text = [] text = []
num_words = rand.randint(300, 400) for _ in range(NUM_WORDS):
for i in range(num_words):
k = rand.randint(1, 7) k = rand.randint(1, 7)
word = rand.choices(CHARS, k=k) word = rand.choices(CHARS, k=k)
space = rand.choice(WHITESPACES) word.append(rand.choice(WHITESPACES))
text.append("".join(word) + space) text.append("".join(word))
yield "".join(text)
def generator_unicodes() -> Iterator[str]:
"""Iterate unicode characters"""
MAX_CODEPOINTS = 0x30000 # 0x110000
def _valid(cpt):
if cpt >= 0x30000: # unassigned and supplement­ary
return False
if 0x00D800 <= cpt <= 0x00F8FF: # Surrogates
return False
if unicodedata.category(chr(cpt)) == "Cn":
return False
return True
characters = [chr(cpt) for cpt in range(1, MAX_CODEPOINTS) if _valid(cpt)]
yield from characters
def generator_random_unicodes(iterations=100) -> Iterator[str]:
"""Brute force random text with unicode characters"""
NUM_WORDS = 200
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
characters = list(generator_unicodes())
rand = random.Random()
for m in range(iterations):
rand.seed(m)
text = []
for _ in range(NUM_WORDS):
k = rand.randint(1, 7)
word = rand.choices(characters, k=k)
word.append(rand.choice(WHITESPACES))
text.append("".join(word))
yield "".join(text) yield "".join(text)
@ -256,25 +307,7 @@ def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[s
yield "".join(text) yield "".join(text)
def generator_random_bytes(iterations=100) -> Iterator[str]: def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
"""Brute force random bytes"""
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
rand = random.Random()
for m in range(iterations):
rand.seed(m)
text = []
num_words = rand.randint(300, 400)
for i in range(num_words):
k = rand.randint(1, 8)
word = [chr(r) for r in rand.randbytes(k) if r]
word.append(rand.choice(WHITESPACES))
text.append("".join(word))
yield "".join(text)
def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
def find_first_mismatch(ids1: list[int], ids2: list[int]): def find_first_mismatch(ids1: list[int], ids2: list[int]):
for i, (a, b) in enumerate(zip(ids1, ids2)): for i, (a, b) in enumerate(zip(ids1, ids2)):
@ -284,20 +317,34 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
return -1 return -1
return min(len(ids1), len(ids2)) return min(len(ids1), len(ids2))
t0 = time.perf_counter() t_tokenizer1 = 0
t_tokenizer2 = 0
t_start = time.perf_counter()
num_errors = 10
logger.info("%s: %s" % (generator.__name__, "ini")) logger.info("%s: %s" % (generator.__name__, "ini"))
for text in generator: for text in generator:
# print(repr(text), hex(ord(text[0])), text.encode())
t0 = time.perf_counter()
ids1 = func_tokenize1(text) ids1 = func_tokenize1(text)
t1 = time.perf_counter()
ids2 = func_tokenize2(text) ids2 = func_tokenize2(text)
t2 = time.perf_counter()
t_tokenizer1 += t1 - t0
t_tokenizer2 += t2 - t1
if ids1 != ids2: if ids1 != ids2:
i = find_first_mismatch(ids1, ids2) i = find_first_mismatch(ids1, ids2)
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1] ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1] ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
logger.info(" TokenIDs: " + str(ids1)) logger.error(" TokenIDs: " + str(ids1))
logger.info(" Expected: " + str(ids2)) logger.error(" Expected: " + str(ids2))
raise Exception() # raise Exception()
t1 = time.perf_counter() num_errors += 1
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0)) if num_errors > 10:
break
t_total = time.perf_counter() - t_start
logger.info("%s: end, tok1: %.3f tok2: %.3f total: %.3f" % (generator.__name__, t_tokenizer1, t_tokenizer2, t_total))
def main(argv: list[str] = None): def main(argv: list[str] = None):
@ -308,6 +355,7 @@ def main(argv: list[str] = None):
args = parser.parse_args(argv) args = parser.parse_args(argv)
logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO) logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
logger.info(f"VOCABFILE: '{args.vocab_file}'")
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096)) model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer) tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
@ -321,18 +369,22 @@ def main(argv: list[str] = None):
ids = func_tokenize2("a") ids = func_tokenize2("a")
assert 1 <= len(ids) <= 3 assert 1 <= len(ids) <= 3
add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0] add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
add_eos_token = len(ids) > 1 and tokenizer.eos_token_id == ids[-1]
tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token) tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", add_eos_token)
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True))) vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases()) compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab)) compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer)) compare_tokenizers(func_tokenize1, func_tokenize2, generator_unicodes())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000)) compare_tokenizers(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000)) compare_tokenizers(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000)) compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000)) compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
# test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_unicodes(10_000))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
model.free() model.free()
@ -340,20 +392,40 @@ def main(argv: list[str] = None):
if __name__ == "__main__": if __name__ == "__main__":
# main() # main()
logging.basicConfig(
level = logging.DEBUG,
format = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
datefmt = "%Y-%m-%d %H:%M:%S",
filename = logger.name + ".log",
filemode = "a"
)
path_tokenizers = "./models/tokenizers/" path_tokenizers = "./models/tokenizers/"
path_vocab_format = "./models/ggml-vocab-%s.gguf" path_vocab_format = "./models/ggml-vocab-%s.gguf"
# import os # import os
# tokenizers = os.listdir(path_tokenizers) # tokenizers = os.listdir(path_tokenizers)
tokenizers = [ tokenizers = [
"llama-spm", # SPM # "llama-spm", # SPM
"phi-3", # SPM # "phi-3", # SPM
"jina-v2-en", # WPM # "bert-bge", # WPM
"bert-bge", # WPM # "jina-v2-en", # WPM
"gpt-2", # BPE
"llama-bpe", # BPE
"falcon", # BPE
"starcoder", # BPE
"jina-v2-es", # BPE
"jina-v2-de", # BPE
"jina-v2-code", # BPE
"smaug-bpe", # BPE
"phi-2", # BPE
"deepseek-coder", # BPE
"deepseek-llm", # BPE
] ]
for tokenizer in tokenizers: for tokenizer in tokenizers:
print("\n" + "=" * 50 + "\n" + tokenizer + "\n") # noqa logger.info("=" * 50)
logger.info(f"TOKENIZER: '{tokenizer}'")
vocab_file = path_vocab_format % tokenizer vocab_file = path_vocab_format % tokenizer
dir_tokenizer = path_tokenizers + "/" + tokenizer dir_tokenizer = path_tokenizers + "/" + tokenizer
main([vocab_file, dir_tokenizer, "--verbose"]) main([vocab_file, dir_tokenizer, "--verbose"])

File diff suppressed because it is too large Load Diff

View File

@ -226,8 +226,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
assert(offset_end <= cpts.size()); assert(offset_end <= cpts.size());
start = offset_end; start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
auto _get_cpt = [&] (const size_t pos) -> uint32_t { auto _get_cpt = [&] (const size_t pos) -> uint32_t {
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
}; };
auto _get_flags = [&] (const size_t pos) -> codepoint_flags { auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
@ -309,7 +310,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
} }
// regex: \s+(?!\S) // regex: \s+(?!\S)
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) { if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
pos += num_whitespaces - 1; pos += num_whitespaces - 1;
_add_token(pos); _add_token(pos);
continue; continue;
@ -344,8 +345,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
assert(offset_end <= cpts.size()); assert(offset_end <= cpts.size());
start = offset_end; start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
auto _get_cpt = [&] (const size_t pos) -> uint32_t { auto _get_cpt = [&] (const size_t pos) -> uint32_t {
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
}; };
auto _get_flags = [&] (const size_t pos) -> codepoint_flags { auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
@ -450,7 +452,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
} }
// regex: \s+(?!\S) // regex: \s+(?!\S)
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) { if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
pos += num_whitespaces - 1; pos += num_whitespaces - 1;
_add_token(pos); _add_token(pos);
continue; continue;
@ -679,10 +681,14 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue; continue;
} }
const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag(); const auto flags = unicode_cpt_flags(cpts[i]);
if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) { if (flags.is_whitespace) {
text_collapsed[i] = k_ucat_cpt.at(cpt_flag); //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
} else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
} else { } else {
text_collapsed[i] = (char) 0xD0; // fallback text_collapsed[i] = (char) 0xD0; // fallback
} }
@ -766,9 +772,16 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
} else { } else {
// no unicode category used, we can use std::wregex directly // no unicode category used, we can use std::wregex directly
const std::wstring wtext = unicode_wstring_from_utf8(text);
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
wtext[i] = 0x0B;
}
}
//printf("text: %s\n", text.c_str()); //printf("text: %s\n", text.c_str());
//printf("regex_expr: %s\n", regex_expr.c_str()); //printf("regex_expr: %s\n", regex_expr.c_str());
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets); bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);