llama : handle added special tokens like HF does

Now the BERT tokenizer actually uses the SEP and CLS tokens from
SpecialVocab.
This commit is contained in:
Jared Van Bortel 2024-03-27 16:59:49 -04:00
parent 748fc8baa3
commit 8803582721
3 changed files with 115 additions and 68 deletions

View File

@ -123,10 +123,10 @@ int main(int argc, char ** argv) {
inputs.push_back(inp); inputs.push_back(inp);
} }
// add eos if not present // add SEP if not present
for (auto & inp : inputs) { for (auto & inp : inputs) {
if (inp.empty() || inp.back() != llama_token_eos(model)) { if (inp.empty() || inp.back() != llama_token_sep(model)) {
inp.push_back(llama_token_eos(model)); inp.push_back(llama_token_sep(model));
} }
} }

167
llama.cpp
View File

@ -315,6 +315,8 @@ enum llm_kv {
LLM_KV_TOKENIZER_UNK_ID, LLM_KV_TOKENIZER_UNK_ID,
LLM_KV_TOKENIZER_SEP_ID, LLM_KV_TOKENIZER_SEP_ID,
LLM_KV_TOKENIZER_PAD_ID, LLM_KV_TOKENIZER_PAD_ID,
LLM_KV_TOKENIZER_CLS_ID,
LLM_KV_TOKENIZER_MASK_ID,
LLM_KV_TOKENIZER_ADD_BOS, LLM_KV_TOKENIZER_ADD_BOS,
LLM_KV_TOKENIZER_ADD_EOS, LLM_KV_TOKENIZER_ADD_EOS,
LLM_KV_TOKENIZER_ADD_PREFIX, LLM_KV_TOKENIZER_ADD_PREFIX,
@ -384,6 +386,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
{ LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
{ LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" },
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
@ -1981,11 +1985,13 @@ struct llama_vocab {
std::map<std::pair<std::string, std::string>, int> bpe_ranks; std::map<std::pair<std::string, std::string>, int> bpe_ranks;
// default LLaMA special tokens // default LLaMA special tokens
id special_bos_id = 1; id special_bos_id = 1;
id special_eos_id = 2; id special_eos_id = 2;
id special_unk_id = 0; id special_unk_id = 0;
id special_sep_id = -1; id special_sep_id = -1;
id special_pad_id = -1; id special_pad_id = -1;
id special_cls_id = -1;
id special_mask_id = -1;
int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add. int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add. int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
@ -3869,7 +3875,9 @@ static void llm_load_hparams(
} }
// TODO: This should probably be in llama.h // TODO: This should probably be in llama.h
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false); static std::vector<llama_vocab::id> llama_tokenize_internal(
const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special = false
);
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch); static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
static void llm_load_vocab( static void llm_load_vocab(
@ -3891,23 +3899,27 @@ static void llm_load_vocab(
vocab.type = LLAMA_VOCAB_TYPE_NONE; vocab.type = LLAMA_VOCAB_TYPE_NONE;
// default special tokens // default special tokens
vocab.special_bos_id = -1; vocab.special_bos_id = -1;
vocab.special_eos_id = -1; vocab.special_eos_id = -1;
vocab.special_unk_id = -1; vocab.special_unk_id = -1;
vocab.special_sep_id = -1; vocab.special_sep_id = -1;
vocab.special_pad_id = -1; vocab.special_pad_id = -1;
vocab.linefeed_id = -1; vocab.special_cls_id = -1;
vocab.special_mask_id = -1;
vocab.linefeed_id = -1;
return; return;
} else if (tokenizer_name == "llama") { } else if (tokenizer_name == "llama") {
vocab.type = LLAMA_VOCAB_TYPE_SPM; vocab.type = LLAMA_VOCAB_TYPE_SPM;
// default special tokens // default special tokens
vocab.special_bos_id = 1; vocab.special_bos_id = 1;
vocab.special_eos_id = 2; vocab.special_eos_id = 2;
vocab.special_unk_id = 0; vocab.special_unk_id = 0;
vocab.special_sep_id = -1; vocab.special_sep_id = -1;
vocab.special_pad_id = -1; vocab.special_pad_id = -1;
vocab.special_cls_id = -1;
vocab.special_mask_id = -1;
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) { if (add_space_prefix_keyidx != -1) {
@ -3942,20 +3954,24 @@ static void llm_load_vocab(
} }
// default special tokens // default special tokens
vocab.special_bos_id = 11; vocab.special_bos_id = 11;
vocab.special_eos_id = 11; vocab.special_eos_id = 11;
vocab.special_unk_id = -1; vocab.special_unk_id = -1;
vocab.special_sep_id = -1; vocab.special_sep_id = -1;
vocab.special_pad_id = -1; vocab.special_pad_id = -1;
vocab.special_cls_id = -1;
vocab.special_mask_id = -1;
} else if (tokenizer_name == "bert") { } else if (tokenizer_name == "bert") {
vocab.type = LLAMA_VOCAB_TYPE_WPM; vocab.type = LLAMA_VOCAB_TYPE_WPM;
// default special tokens // default special tokens
vocab.special_bos_id = 101; vocab.special_bos_id = -1;
vocab.special_eos_id = 102; vocab.special_eos_id = -1;
vocab.special_unk_id = 100; vocab.special_unk_id = 100;
vocab.special_sep_id = -1; vocab.special_sep_id = 102;
vocab.special_pad_id = -1; vocab.special_pad_id = 0;
vocab.special_cls_id = 101;
vocab.special_mask_id = 103;
vocab.add_space_prefix = false; vocab.add_space_prefix = false;
} else { } else {
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
@ -4018,11 +4034,13 @@ static void llm_load_vocab(
// special tokens // special tokens
{ {
const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = { const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
{ LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
{ LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
{ LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
{ LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
{ LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
{ LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
{ LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
}; };
for (const auto & it : special_token_types) { for (const auto & it : special_token_types) {
const std::string & key = kv(std::get<0>(it)); const std::string & key = kv(std::get<0>(it));
@ -4214,12 +4232,14 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str());
// special tokens // special tokens
if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); }
if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
} }
// Returns false if cancelled by progress_callback // Returns false if cancelled by progress_callback
@ -10995,9 +11015,6 @@ struct llm_tokenizer_wpm {
output.push_back(vocab.special_unk_id); output.push_back(vocab.special_unk_id);
} }
} }
// append eos token
output.push_back(vocab.special_eos_id);
} }
std::vector<std::string> preprocess(const std::string & text) { std::vector<std::string> preprocess(const std::string & text) {
@ -11202,30 +11219,28 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
} }
} }
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) { static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
std::vector<llama_vocab::id> output; std::vector<llama_vocab::id> output;
// OG tokenizer behavior:
//
// tokenizer.encode('', add_bos=True) returns [1]
// tokenizer.encode('', add_bos=False) returns []
if (bos && vocab.special_bos_id != -1) {
output.push_back(vocab.special_bos_id);
}
if (raw_text.empty()) {
return output;
}
std::forward_list<fragment_buffer_variant> fragment_buffer; std::forward_list<fragment_buffer_variant> fragment_buffer;
fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
if (special) tokenizer_st_partition(vocab, fragment_buffer); if (!raw_text.empty()) {
fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
}
switch (vocab.type) { switch (vocab.type) {
case LLAMA_VOCAB_TYPE_SPM: case LLAMA_VOCAB_TYPE_SPM:
{ {
// OG tokenizer behavior:
//
// tokenizer.encode('', add_special_tokens=True) returns [1]
// tokenizer.encode('', add_special_tokens=False) returns []
if (add_special && vocab.special_add_bos == 1) {
GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id);
}
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
// without adding this leading whitespace, we do not get the same results as the original tokenizer // without adding this leading whitespace, we do not get the same results as the original tokenizer
@ -11251,9 +11266,19 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
output.push_back(fragment.token); output.push_back(fragment.token);
} }
} }
if (add_special && vocab.special_add_eos == 1) {
GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id);
}
} break; } break;
case LLAMA_VOCAB_TYPE_BPE: case LLAMA_VOCAB_TYPE_BPE:
{ {
if (add_special && vocab.special_add_bos == 1) {
GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id);
}
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@ -11267,9 +11292,16 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
output.push_back(fragment.token); output.push_back(fragment.token);
} }
} }
GGML_ASSERT(vocab.special_add_eos != 1);
} break; } break;
case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_WPM:
{ {
if (add_special) {
GGML_ASSERT(vocab.special_cls_id != -1);
output.push_back(vocab.special_cls_id);
}
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@ -11283,6 +11315,11 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
output.push_back(fragment.token); output.push_back(fragment.token);
} }
} }
if (add_special) {
GGML_ASSERT(vocab.special_sep_id != -1);
output.push_back(vocab.special_sep_id);
}
} break; } break;
case LLAMA_VOCAB_TYPE_NONE: case LLAMA_VOCAB_TYPE_NONE:
GGML_ASSERT(false); GGML_ASSERT(false);
@ -15272,6 +15309,14 @@ llama_token llama_token_eos(const struct llama_model * model) {
return model->vocab.special_eos_id; return model->vocab.special_eos_id;
} }
llama_token llama_token_cls(const struct llama_model * model) {
return model->vocab.special_cls_id;
}
llama_token llama_token_sep(const struct llama_model * model) {
return model->vocab.special_sep_id;
}
llama_token llama_token_nl(const struct llama_model * model) { llama_token llama_token_nl(const struct llama_model * model) {
return model->vocab.linefeed_id; return model->vocab.linefeed_id;
} }
@ -15306,9 +15351,9 @@ int32_t llama_tokenize(
int32_t text_len, int32_t text_len,
llama_token * tokens, llama_token * tokens,
int32_t n_tokens_max, int32_t n_tokens_max,
bool add_bos, bool add_special,
bool special) { bool parse_special) {
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special); auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
if (n_tokens_max < (int) res.size()) { if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);

10
llama.h
View File

@ -721,6 +721,8 @@ extern "C" {
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
// Returns -1 if unknown, 1 for true or 0 for false. // Returns -1 if unknown, 1 for true or 0 for false.
@ -743,16 +745,16 @@ extern "C" {
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens. /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
/// @return Returns the number of tokens on success, no more than n_tokens_max /// @return Returns the number of tokens on success, no more than n_tokens_max
/// @return Returns a negative number on failure - the number of tokens that would have been returned /// @return Returns a negative number on failure - the number of tokens that would have been returned
/// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
/// Does not insert a leading space. /// as plaintext. Does not insert a leading space.
LLAMA_API int32_t llama_tokenize( LLAMA_API int32_t llama_tokenize(
const struct llama_model * model, const struct llama_model * model,
const char * text, const char * text,
int32_t text_len, int32_t text_len,
llama_token * tokens, llama_token * tokens,
int32_t n_tokens_max, int32_t n_tokens_max,
bool add_bos, bool add_special,
bool special); bool parse_special);
// Token Id -> Piece. // Token Id -> Piece.
// Uses the vocabulary in the provided context. // Uses the vocabulary in the provided context.