tokenizer : BPE fixes (#7530)

* Random test: add_bos_token, add_eos_token
* Random test: add BPE models for testing
* Custom regex split fails with codepoint 0
* Fix falcon punctuation regex
* Refactor llm_tokenizer_bpe: move code to constructor
* Move 'add_special_bos/eos' logic to llm_tokenizer_bpe
* Move tokenizer flags to vocab structure.
* Default values for special_add_bos/eos
* Build vocab.special_tokens_cache using vocab token types
* Generalize 'jina-v2' per token attributes
* Fix unicode whitespaces (deepseek-coder, deepseek-llm)
* Skip missing byte tokens (falcon)
* Better unicode data generation
* Replace char32_t with uint32_t
This commit is contained in:
jaime-m-p 2024-06-18 18:40:52 +02:00 committed by GitHub
parent 91c188d6c2
commit 37bef89433
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1283 additions and 1053 deletions

309
llama.cpp
View File

@ -2310,16 +2310,17 @@ struct llama_vocab {
id special_cls_id = -1;
id special_mask_id = -1;
int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
id linefeed_id = 13;
id special_prefix_id = -1;
id special_suffix_id = -1;
id special_middle_id = -1;
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
bool add_space_prefix = true;
// tokenizer flags
bool tokenizer_add_space_prefix = true;
bool tokenizer_add_bos = false;
bool tokenizer_add_eos = false;
bool tokenizer_ignore_merges = false;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
GGML_ASSERT(token_left.find(' ') == std::string::npos);
@ -4770,7 +4771,7 @@ static void llm_load_vocab(
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) {
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
} // The default value of add_space_prefix is true.
} else if (tokenizer_model == "bert") {
vocab.type = LLAMA_VOCAB_TYPE_WPM;
@ -4783,13 +4784,13 @@ static void llm_load_vocab(
vocab.special_pad_id = 0;
vocab.special_cls_id = 101;
vocab.special_mask_id = 103;
vocab.add_space_prefix = false;
vocab.tokenizer_add_space_prefix = false;
} else if (tokenizer_model == "gpt2") {
vocab.type = LLAMA_VOCAB_TYPE_BPE;
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) {
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
}
// read bpe merges and populate bpe ranks
@ -4847,6 +4848,8 @@ static void llm_load_vocab(
tokenizer_pre == "llama-v3" ||
tokenizer_pre == "llama-bpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
vocab.tokenizer_ignore_merges = true;
vocab.tokenizer_add_bos = true;
} else if (
tokenizer_pre == "deepseek-llm") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
@ -4897,6 +4900,14 @@ static void llm_load_vocab(
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
} else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
vocab.tokenizer_add_bos = true;
vocab.tokenizer_add_eos = false;
} else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
vocab.tokenizer_add_bos = true;
vocab.tokenizer_add_eos = false;
} else {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
}
@ -5041,10 +5052,10 @@ static void llm_load_vocab(
bool temp = true;
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
vocab.special_add_bos = int(temp);
vocab.tokenizer_add_bos = temp;
}
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
vocab.special_add_eos = int(temp);
vocab.tokenizer_add_eos = temp;
}
}
@ -5144,7 +5155,7 @@ static void llm_load_vocab(
);
// set attributes by model/tokenizer name
if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
} else if (_contains_any(model_name, {"phi-3", "phi3"})) {
for (auto id : vocab.cache_special_tokens) {
@ -13158,112 +13169,142 @@ struct llm_bigram_bpe {
};
struct llm_tokenizer_bpe {
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {}
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
switch (vocab.type_pre) {
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
regex_exprs = {
// original regex from tokenizer.json
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
// adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_DBRX:
case LLAMA_VOCAB_PRE_TYPE_SMAUG:
regex_exprs = {
// same as llama3
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
regex_exprs = {
"[\r\n]",
"\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
"\\s?[!-/:-~---‟ -。]+",
"\\s+$",
"[一-龥ࠀ-一가-퟿]+",
"\\p{N}+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
regex_exprs = {
"[\r\n]",
"\\s?\\p{L}+",
"\\s?\\p{P}+",
"[一-龥ࠀ-一가-퟿]+",
"\\p{N}",
};
break;
case LLAMA_VOCAB_PRE_TYPE_FALCON:
regex_exprs = {
"[\\p{P}\\$\\+<=>\\^~\\|`]+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
"[0-9][0-9][0-9]",
};
break;
case LLAMA_VOCAB_PRE_TYPE_MPT:
// TODO: MPT pre-tokenization regexes are unknown
// the following are close, but not exact. run the following:
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
regex_exprs = {
"\\s?\\p{L}+",
"\\s?\\p{P}+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
break;
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
case LLAMA_VOCAB_PRE_TYPE_REFACT:
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
regex_exprs = {
"\\p{N}",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
break;
case LLAMA_VOCAB_PRE_TYPE_GPT2:
case LLAMA_VOCAB_PRE_TYPE_OLMO:
regex_exprs = {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
break;
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
regex_exprs = {
// original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_PORO:
regex_exprs = {
" ?[^(\\s|.,!?…。,、।۔،)]+",
};
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
"[\\p{P}\\$\\+<=>\\^~\\|]+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
"\\p{N}+",
"[0-9][0-9][0-9]",
};
break;
}
}
void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
output.push_back(token_id);
}
bool append_bos(std::vector<llama_vocab::id> & output) const {
if (vocab.tokenizer_add_bos) {
GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id);
return true;
}
return false;
}
bool append_eos(std::vector<llama_vocab::id> & output) const {
if (vocab.tokenizer_add_eos) {
GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id);
return true;
}
return false;
}
void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
LLAMA_LOG_WARN(
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
LLAMA_LOG_WARN(
"%s: Added a EOS token to the prompt as specified by the model but the prompt "
"also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
int final_prev_index = -1;
bool ignore_merges = false;
std::vector<std::string> word_collection;
switch (vocab.type) {
case LLAMA_VOCAB_TYPE_BPE:
switch (vocab.type_pre) {
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
ignore_merges = true;
word_collection = unicode_regex_split(text, {
// original regex from tokenizer.json
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
// adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
});
break;
case LLAMA_VOCAB_PRE_TYPE_DBRX:
case LLAMA_VOCAB_PRE_TYPE_SMAUG:
word_collection = unicode_regex_split(text, {
// same as llama3
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
});
break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
word_collection = unicode_regex_split(text, {
"[\r\n]",
"\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
"\\s?[!-/:-~---‟ -。]+",
"\\s+$",
"[一-龥ࠀ-一가-퟿]+",
"\\p{N}+",
});
break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
word_collection = unicode_regex_split(text, {
"[\r\n]",
"\\s?\\p{L}+",
"\\s?\\p{P}+",
"[一-龥ࠀ-一가-퟿]+",
"\\p{N}",
});
break;
case LLAMA_VOCAB_PRE_TYPE_FALCON:
word_collection = unicode_regex_split(text, {
"[\\p{P}\\$\\+<=>\\^~\\|]+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
"[0-9][0-9][0-9]",
});
break;
case LLAMA_VOCAB_PRE_TYPE_MPT:
// TODO: MPT pre-tokenization regexes are unknown
// the following are close, but not exact. run the following:
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
word_collection = unicode_regex_split(text, {
"\\s?\\p{L}+",
"\\s?\\p{P}+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
});
break;
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
case LLAMA_VOCAB_PRE_TYPE_REFACT:
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
word_collection = unicode_regex_split(text, {
"\\p{N}",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
});
break;
case LLAMA_VOCAB_PRE_TYPE_GPT2:
case LLAMA_VOCAB_PRE_TYPE_OLMO:
word_collection = unicode_regex_split(text, {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
});
break;
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
word_collection = unicode_regex_split(text, {
// original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
});
break;
case LLAMA_VOCAB_PRE_TYPE_PORO:
word_collection = unicode_regex_split(text, {
" ?[^(\\s|.,!?…。,、।۔،)]+",
});
break;
default:
// default regex for BPE tokenization pre-processing
word_collection = unicode_regex_split(text, {
"[\\p{P}\\$\\+<=>\\^~\\|]+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
"\\p{N}+",
"[0-9][0-9][0-9]",
});
break;
}
break;
default:
GGML_ASSERT(false);
break;
}
const auto word_collection = unicode_regex_split(text, regex_exprs);
symbols_final.clear();
@ -13274,7 +13315,7 @@ struct llm_tokenizer_bpe {
int index = 0;
size_t offset = 0;
if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
offset = word.size();
}
@ -13355,10 +13396,9 @@ struct llm_tokenizer_bpe {
for (auto j = str.begin(); j != str.end(); ++j) {
std::string byte_str(1, *j);
auto token_multibyte = vocab.token_to_id.find(byte_str);
if (token_multibyte == vocab.token_to_id.end()) {
throw std::runtime_error("ERROR: byte not found in vocab");
if (token_multibyte != vocab.token_to_id.end()) {
output.push_back(token_multibyte->second);
}
output.push_back((*token_multibyte).second);
}
} else {
output.push_back((*token).second);
@ -13397,6 +13437,8 @@ private:
const llama_vocab & vocab;
std::vector<std::string> regex_exprs;
std::vector<llm_symbol> symbols;
std::vector<llm_symbol> symbols_final;
@ -13677,7 +13719,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
bool is_prev_special = false;
if (add_special && vocab.special_add_bos != 0) {
if (add_special && vocab.tokenizer_add_bos) {
GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id);
is_prev_special = true;
@ -13687,7 +13729,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
if (vocab.add_space_prefix) {
if (vocab.tokenizer_add_space_prefix) {
if (!output.size() || is_prev_special) { // prefix with space if first token
raw_text = " " + raw_text;
}
@ -13705,23 +13747,24 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
}
}
if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
LLAMA_LOG_WARN(
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
if (add_special && vocab.special_add_eos == 1) {
if (add_special && vocab.tokenizer_add_eos) {
GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id);
}
} break;
case LLAMA_VOCAB_TYPE_BPE:
{
if (add_special && vocab.special_add_bos != 0) {
GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id);
llm_tokenizer_bpe tokenizer(vocab);
if (add_special) {
tokenizer.append_bos(output);
}
for (const auto & fragment : fragment_buffer) {
@ -13731,23 +13774,15 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif
llm_tokenizer_bpe tokenizer(vocab);
tokenizer.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
tokenizer.append(fragment.token, output);
}
}
if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
LLAMA_LOG_WARN(
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
if (add_special && vocab.special_add_eos == 1) {
GGML_ASSERT(vocab.special_add_eos != -1);
output.push_back(vocab.special_eos_id);
if (add_special) {
tokenizer.append_eos(output);
tokenizer.check_double_bos_eos(output);
}
} break;
case LLAMA_VOCAB_TYPE_WPM:
@ -18320,11 +18355,11 @@ llama_token llama_token_nl(const struct llama_model * model) {
}
int32_t llama_add_bos_token(const struct llama_model * model) {
return model->vocab.special_add_bos;
return model->vocab.tokenizer_add_bos;
}
int32_t llama_add_eos_token(const struct llama_model * model) {
return model->vocab.special_add_eos;
return model->vocab.tokenizer_add_eos;
}
llama_token llama_token_prefix(const struct llama_model * model) {

View File

@ -1,83 +1,143 @@
import regex
import ctypes
import array
import unicodedata
class CoodepointFlags (ctypes.Structure):
_fields_ = [ # see definition in unicode.h
("is_undefined", ctypes.c_uint16, 1),
("is_number", ctypes.c_uint16, 1), # regex: \p{N}
("is_letter", ctypes.c_uint16, 1), # regex: \p{L}
("is_separator", ctypes.c_uint16, 1), # regex: \p{Z}
("is_accent_mark", ctypes.c_uint16, 1), # regex: \p{M}
("is_punctuation", ctypes.c_uint16, 1), # regex: \p{P}
("is_symbol", ctypes.c_uint16, 1), # regex: \p{S}
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
]
assert (ctypes.sizeof(CoodepointFlags) == 2)
import requests
MAX_CODEPOINTS = 0x110000
regex_number = regex.compile(r'\p{N}')
regex_letter = regex.compile(r'\p{L}')
regex_separator = regex.compile(r'\p{Z}')
regex_accent_mark = regex.compile(r'\p{M}')
regex_punctuation = regex.compile(r'\p{P}')
regex_symbol = regex.compile(r'\p{S}')
regex_control = regex.compile(r'\p{C}')
regex_whitespace = regex.compile(r'\s')
UNICODE_DATA_URL = "https://www.unicode.org/Public/UCD/latest/ucd/UnicodeData.txt"
codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
# see https://www.unicode.org/L2/L1999/UnicodeData.html
def unicode_data_iter():
res = requests.get(UNICODE_DATA_URL)
res.raise_for_status()
data = res.content.decode()
prev = []
for line in data.splitlines():
# ej: 0000;<control>;Cc;0;BN;;;;;N;NULL;;;;
line = line.split(";")
cpt = int(line[0], base=16)
assert cpt < MAX_CODEPOINTS
cpt_lower = int(line[-2] or "0", base=16)
assert cpt_lower < MAX_CODEPOINTS
cpt_upper = int(line[-3] or "0", base=16)
assert cpt_upper < MAX_CODEPOINTS
categ = line[2].strip()
assert len(categ) == 2
bidir = line[4].strip()
assert len(categ) == 2
name = line[1]
if name.endswith(", First>"):
prev = (cpt, cpt_lower, cpt_upper, categ, bidir)
continue
if name.endswith(", Last>"):
assert prev[1:] == (0, 0, categ, bidir)
for c in range(prev[0], cpt):
yield (c, cpt_lower, cpt_upper, categ, bidir)
yield (cpt, cpt_lower, cpt_upper, categ, bidir)
# see definition in unicode.h
CODEPOINT_FLAG_UNDEFINED = 0x0001 #
CODEPOINT_FLAG_NUMBER = 0x0002 # \p{N}
CODEPOINT_FLAG_LETTER = 0x0004 # \p{L}
CODEPOINT_FLAG_SEPARATOR = 0x0008 # \p{Z}
CODEPOINT_FLAG_MARK = 0x0010 # \p{M}
CODEPOINT_FLAG_PUNCTUATION = 0x0020 # \p{P}
CODEPOINT_FLAG_SYMBOL = 0x0040 # \p{S}
CODEPOINT_FLAG_CONTROL = 0x0080 # \p{C}
UNICODE_CATEGORY_TO_FLAG = {
"Cn": CODEPOINT_FLAG_UNDEFINED, # Undefined
"Cc": CODEPOINT_FLAG_CONTROL, # Control
"Cf": CODEPOINT_FLAG_CONTROL, # Format
"Co": CODEPOINT_FLAG_CONTROL, # Private Use
"Cs": CODEPOINT_FLAG_CONTROL, # Surrrogate
"Ll": CODEPOINT_FLAG_LETTER, # Lowercase Letter
"Lm": CODEPOINT_FLAG_LETTER, # Modifier Letter
"Lo": CODEPOINT_FLAG_LETTER, # Other Letter
"Lt": CODEPOINT_FLAG_LETTER, # Titlecase Letter
"Lu": CODEPOINT_FLAG_LETTER, # Uppercase Letter
"L&": CODEPOINT_FLAG_LETTER, # Cased Letter
"Mc": CODEPOINT_FLAG_MARK, # Spacing Mark
"Me": CODEPOINT_FLAG_MARK, # Enclosing Mark
"Mn": CODEPOINT_FLAG_MARK, # Nonspacing Mark
"Nd": CODEPOINT_FLAG_NUMBER, # Decimal Number
"Nl": CODEPOINT_FLAG_NUMBER, # Letter Number
"No": CODEPOINT_FLAG_NUMBER, # Other Number
"Pc": CODEPOINT_FLAG_PUNCTUATION, # Connector Punctuation
"Pd": CODEPOINT_FLAG_PUNCTUATION, # Dash Punctuation
"Pe": CODEPOINT_FLAG_PUNCTUATION, # Close Punctuation
"Pf": CODEPOINT_FLAG_PUNCTUATION, # Final Punctuation
"Pi": CODEPOINT_FLAG_PUNCTUATION, # Initial Punctuation
"Po": CODEPOINT_FLAG_PUNCTUATION, # Other Punctuation
"Ps": CODEPOINT_FLAG_PUNCTUATION, # Open Punctuation
"Sc": CODEPOINT_FLAG_SYMBOL, # Currency Symbol
"Sk": CODEPOINT_FLAG_SYMBOL, # Modifier Symbol
"Sm": CODEPOINT_FLAG_SYMBOL, # Math Symbol
"So": CODEPOINT_FLAG_SYMBOL, # Other Symbol
"Zl": CODEPOINT_FLAG_SEPARATOR, # Line Separator
"Zp": CODEPOINT_FLAG_SEPARATOR, # Paragraph Separator
"Zs": CODEPOINT_FLAG_SEPARATOR, # Space Separator
}
codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS
table_whitespace = []
table_lowercase = []
table_uppercase = []
table_nfd = []
for codepoint in range(MAX_CODEPOINTS):
for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
# convert codepoint to unicode character
char = chr(codepoint)
char = chr(cpt)
# regex categories
flags = codepoint_flags[codepoint]
flags.is_number = bool(regex_number.match(char))
flags.is_letter = bool(regex_letter.match(char))
flags.is_separator = bool(regex_separator.match(char))
flags.is_accent_mark = bool(regex_accent_mark.match(char))
flags.is_punctuation = bool(regex_punctuation.match(char))
flags.is_symbol = bool(regex_symbol.match(char))
flags.is_control = bool(regex_control.match(char))
flags.is_undefined = bytes(flags)[0] == 0
assert (not flags.is_undefined)
# whitespaces
if bool(regex_whitespace.match(char)):
table_whitespace.append(codepoint)
# codepoint category flags
codepoint_flags[cpt] = UNICODE_CATEGORY_TO_FLAG[categ]
# lowercase conversion
lower = ord(char.lower()[0])
if codepoint != lower:
table_lowercase.append((codepoint, lower))
if cpt_lower:
table_lowercase.append((cpt, cpt_lower))
# uppercase conversion
upper = ord(char.upper()[0])
if codepoint != upper:
table_uppercase.append((codepoint, upper))
if cpt_upper:
table_uppercase.append((cpt, cpt_upper))
# NFD normalization
norm = ord(unicodedata.normalize('NFD', char)[0])
if codepoint != norm:
table_nfd.append((codepoint, norm))
if cpt != norm:
table_nfd.append((cpt, norm))
# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
table_whitespace.extend(range(0x0009, 0x000D + 1))
table_whitespace.extend(range(0x2000, 0x200A + 1))
table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000])
# sort by codepoint
table_whitespace.sort()
table_lowercase.sort()
table_uppercase.sort()
table_nfd.sort()
# group ranges with same flags
ranges_flags = [(0, codepoint_flags[0])] # start, flags
for codepoint, flags in enumerate(codepoint_flags):
if bytes(flags) != bytes(ranges_flags[-1][1]):
if flags != ranges_flags[-1][1]:
ranges_flags.append((codepoint, flags))
ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
ranges_flags.append((MAX_CODEPOINTS, 0x0000))
# group ranges with same nfd
@ -90,8 +150,8 @@ for codepoint, norm in table_nfd:
ranges_nfd[-1] = (start, codepoint, norm)
# Generate 'unicode-data.cpp'
# Generate 'unicode-data.cpp':
# python ./scripts//gen-unicode-data.py > unicode-data.cpp
def out(line=""):
print(line, end='\n') # noqa
@ -110,12 +170,12 @@ out("""\
out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
for codepoint, flags in ranges_flags:
flags = int.from_bytes(bytes(flags), "little")
out("{0x%06X, 0x%04X}," % (codepoint, flags))
out("};\n")
out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
for codepoint in table_whitespace:
out("0x%06X," % codepoint)
out("};\n")
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")

View File

@ -11,13 +11,15 @@ import logging
import argparse
import subprocess
import random
import unicodedata
from typing import Callable, Iterator
import cffi
from transformers import AutoTokenizer
logger = logging.getLogger("test-tokenizer-random-bpe")
logger = logging.getLogger("test-tokenizer-random")
class LibLlama:
@ -155,9 +157,14 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
'Cửa Việt', # llama-3, ignore_merges = true
'<s>a', # Phi-3 fail
'<unk><|endoftext|><s>', # Phi-3 fail
'a\na', # TODO: Bert fail
'a </s> b', # rstrip phi-3
'a <mask> b', # lstrip jina-v2
'a\na', # bert fail
'"`', # falcon
' \u2e4e', # falcon
'a\xa0\xa0\x00b', # jina-v2-es
'one <mask>', # jina-v2-es <mask> lstrip=true
'a </s> b', # rstrip phi-3
'a <mask> b', # lstrip jina-v2
'\xa0aC', # deepseek
]
@ -189,17 +196,23 @@ def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
for m in range(iterations):
rand.seed(m)
words = rand.choices(all_tokens, k=500)
if words[0] == tokenizer.bos_token: # skip spam warning of double BOS
if words and words[0] == tokenizer.bos_token: # skip spam warning of double BOS
while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS
words.pop(0)
if tokenizer.add_bos_token: # drop all starting BOS
words.pop(0)
if words and words[-1] == tokenizer.eos_token: # skip spam warning of double EOS
while len(words) > 1 and words[-2] == tokenizer.eos_token: # leave one trailing EOS
words.pop(-1)
if tokenizer.add_bos_token: # drop all trailing EOS
words.pop(-1)
yield "".join(words)
def generator_random_chars(iterations=100) -> Iterator[str]:
"""Brute force random text with simple characters"""
NUM_WORDS = 400
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
CHARS = list(sorted(set("""
ABCDEFGHIJKLMNOPQRSTUVWXYZ
@ -213,12 +226,50 @@ def generator_random_chars(iterations=100) -> Iterator[str]:
for m in range(iterations):
rand.seed(m)
text = []
num_words = rand.randint(300, 400)
for i in range(num_words):
for _ in range(NUM_WORDS):
k = rand.randint(1, 7)
word = rand.choices(CHARS, k=k)
space = rand.choice(WHITESPACES)
text.append("".join(word) + space)
word.append(rand.choice(WHITESPACES))
text.append("".join(word))
yield "".join(text)
def generator_unicodes() -> Iterator[str]:
"""Iterate unicode characters"""
MAX_CODEPOINTS = 0x30000 # 0x110000
def _valid(cpt):
if cpt >= 0x30000: # unassigned and supplement­ary
return False
if 0x00D800 <= cpt <= 0x00F8FF: # Surrogates
return False
if unicodedata.category(chr(cpt)) == "Cn":
return False
return True
characters = [chr(cpt) for cpt in range(1, MAX_CODEPOINTS) if _valid(cpt)]
yield from characters
def generator_random_unicodes(iterations=100) -> Iterator[str]:
"""Brute force random text with unicode characters"""
NUM_WORDS = 200
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
characters = list(generator_unicodes())
rand = random.Random()
for m in range(iterations):
rand.seed(m)
text = []
for _ in range(NUM_WORDS):
k = rand.randint(1, 7)
word = rand.choices(characters, k=k)
word.append(rand.choice(WHITESPACES))
text.append("".join(word))
yield "".join(text)
@ -256,25 +307,7 @@ def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[s
yield "".join(text)
def generator_random_bytes(iterations=100) -> Iterator[str]:
"""Brute force random bytes"""
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
rand = random.Random()
for m in range(iterations):
rand.seed(m)
text = []
num_words = rand.randint(300, 400)
for i in range(num_words):
k = rand.randint(1, 8)
word = [chr(r) for r in rand.randbytes(k) if r]
word.append(rand.choice(WHITESPACES))
text.append("".join(word))
yield "".join(text)
def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
def find_first_mismatch(ids1: list[int], ids2: list[int]):
for i, (a, b) in enumerate(zip(ids1, ids2)):
@ -284,20 +317,34 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
return -1
return min(len(ids1), len(ids2))
t0 = time.perf_counter()
t_tokenizer1 = 0
t_tokenizer2 = 0
t_start = time.perf_counter()
num_errors = 10
logger.info("%s: %s" % (generator.__name__, "ini"))
for text in generator:
# print(repr(text), hex(ord(text[0])), text.encode())
t0 = time.perf_counter()
ids1 = func_tokenize1(text)
t1 = time.perf_counter()
ids2 = func_tokenize2(text)
t2 = time.perf_counter()
t_tokenizer1 += t1 - t0
t_tokenizer2 += t2 - t1
if ids1 != ids2:
i = find_first_mismatch(ids1, ids2)
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
logger.info(" TokenIDs: " + str(ids1))
logger.info(" Expected: " + str(ids2))
raise Exception()
t1 = time.perf_counter()
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
logger.error(" TokenIDs: " + str(ids1))
logger.error(" Expected: " + str(ids2))
# raise Exception()
num_errors += 1
if num_errors > 10:
break
t_total = time.perf_counter() - t_start
logger.info("%s: end, tok1: %.3f tok2: %.3f total: %.3f" % (generator.__name__, t_tokenizer1, t_tokenizer2, t_total))
def main(argv: list[str] = None):
@ -307,7 +354,8 @@ def main(argv: list[str] = None):
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
args = parser.parse_args(argv)
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
logger.info(f"VOCABFILE: '{args.vocab_file}'")
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
@ -321,18 +369,22 @@ def main(argv: list[str] = None):
ids = func_tokenize2("a")
assert 1 <= len(ids) <= 3
add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
add_eos_token = len(ids) > 1 and tokenizer.eos_token_id == ids[-1]
tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", add_eos_token)
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
# test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text())
compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
compare_tokenizers(func_tokenize1, func_tokenize2, generator_unicodes())
compare_tokenizers(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_unicodes(10_000))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
model.free()
@ -340,20 +392,40 @@ def main(argv: list[str] = None):
if __name__ == "__main__":
# main()
logging.basicConfig(
level = logging.DEBUG,
format = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
datefmt = "%Y-%m-%d %H:%M:%S",
filename = logger.name + ".log",
filemode = "a"
)
path_tokenizers = "./models/tokenizers/"
path_vocab_format = "./models/ggml-vocab-%s.gguf"
# import os
# tokenizers = os.listdir(path_tokenizers)
tokenizers = [
"llama-spm", # SPM
"phi-3", # SPM
"jina-v2-en", # WPM
"bert-bge", # WPM
# "llama-spm", # SPM
# "phi-3", # SPM
# "bert-bge", # WPM
# "jina-v2-en", # WPM
"gpt-2", # BPE
"llama-bpe", # BPE
"falcon", # BPE
"starcoder", # BPE
"jina-v2-es", # BPE
"jina-v2-de", # BPE
"jina-v2-code", # BPE
"smaug-bpe", # BPE
"phi-2", # BPE
"deepseek-coder", # BPE
"deepseek-llm", # BPE
]
for tokenizer in tokenizers:
print("\n" + "=" * 50 + "\n" + tokenizer + "\n") # noqa
logger.info("=" * 50)
logger.info(f"TOKENIZER: '{tokenizer}'")
vocab_file = path_vocab_format % tokenizer
dir_tokenizer = path_tokenizers + "/" + tokenizer
main([vocab_file, dir_tokenizer, "--verbose"])

File diff suppressed because it is too large Load Diff

View File

@ -226,8 +226,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
assert(offset_end <= cpts.size());
start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
@ -309,7 +310,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
}
// regex: \s+(?!\S)
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
pos += num_whitespaces - 1;
_add_token(pos);
continue;
@ -344,8 +345,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
assert(offset_end <= cpts.size());
start = offset_end;
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
@ -450,7 +452,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
}
// regex: \s+(?!\S)
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
pos += num_whitespaces - 1;
_add_token(pos);
continue;
@ -679,10 +681,14 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue;
}
const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
const auto flags = unicode_cpt_flags(cpts[i]);
if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
if (flags.is_whitespace) {
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
} else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
} else {
text_collapsed[i] = (char) 0xD0; // fallback
}
@ -766,9 +772,16 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
} else {
// no unicode category used, we can use std::wregex directly
const std::wstring wtext = unicode_wstring_from_utf8(text);
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
wtext[i] = 0x0B;
}
}
//printf("text: %s\n", text.c_str());
//printf("regex_expr: %s\n", regex_expr.c_str());
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);