mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 22:38:58 +01:00
llama : handle added special tokens like HF does
Now the BERT tokenizer actually uses the SEP and CLS tokens from SpecialVocab.
This commit is contained in:
parent
748fc8baa3
commit
8803582721
@ -123,10 +123,10 @@ int main(int argc, char ** argv) {
|
|||||||
inputs.push_back(inp);
|
inputs.push_back(inp);
|
||||||
}
|
}
|
||||||
|
|
||||||
// add eos if not present
|
// add SEP if not present
|
||||||
for (auto & inp : inputs) {
|
for (auto & inp : inputs) {
|
||||||
if (inp.empty() || inp.back() != llama_token_eos(model)) {
|
if (inp.empty() || inp.back() != llama_token_sep(model)) {
|
||||||
inp.push_back(llama_token_eos(model));
|
inp.push_back(llama_token_sep(model));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
101
llama.cpp
101
llama.cpp
@ -315,6 +315,8 @@ enum llm_kv {
|
|||||||
LLM_KV_TOKENIZER_UNK_ID,
|
LLM_KV_TOKENIZER_UNK_ID,
|
||||||
LLM_KV_TOKENIZER_SEP_ID,
|
LLM_KV_TOKENIZER_SEP_ID,
|
||||||
LLM_KV_TOKENIZER_PAD_ID,
|
LLM_KV_TOKENIZER_PAD_ID,
|
||||||
|
LLM_KV_TOKENIZER_CLS_ID,
|
||||||
|
LLM_KV_TOKENIZER_MASK_ID,
|
||||||
LLM_KV_TOKENIZER_ADD_BOS,
|
LLM_KV_TOKENIZER_ADD_BOS,
|
||||||
LLM_KV_TOKENIZER_ADD_EOS,
|
LLM_KV_TOKENIZER_ADD_EOS,
|
||||||
LLM_KV_TOKENIZER_ADD_PREFIX,
|
LLM_KV_TOKENIZER_ADD_PREFIX,
|
||||||
@ -384,6 +386,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||||||
{ LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
|
{ LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
|
||||||
{ LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
|
{ LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
|
||||||
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
|
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
|
||||||
|
{ LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" },
|
||||||
|
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
|
||||||
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
|
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
|
||||||
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
|
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
|
||||||
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
|
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
|
||||||
@ -1986,6 +1990,8 @@ struct llama_vocab {
|
|||||||
id special_unk_id = 0;
|
id special_unk_id = 0;
|
||||||
id special_sep_id = -1;
|
id special_sep_id = -1;
|
||||||
id special_pad_id = -1;
|
id special_pad_id = -1;
|
||||||
|
id special_cls_id = -1;
|
||||||
|
id special_mask_id = -1;
|
||||||
|
|
||||||
int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
|
int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
|
||||||
int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
|
int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
|
||||||
@ -3869,7 +3875,9 @@ static void llm_load_hparams(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: This should probably be in llama.h
|
// TODO: This should probably be in llama.h
|
||||||
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false);
|
static std::vector<llama_vocab::id> llama_tokenize_internal(
|
||||||
|
const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special = false
|
||||||
|
);
|
||||||
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
|
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
|
||||||
|
|
||||||
static void llm_load_vocab(
|
static void llm_load_vocab(
|
||||||
@ -3896,6 +3904,8 @@ static void llm_load_vocab(
|
|||||||
vocab.special_unk_id = -1;
|
vocab.special_unk_id = -1;
|
||||||
vocab.special_sep_id = -1;
|
vocab.special_sep_id = -1;
|
||||||
vocab.special_pad_id = -1;
|
vocab.special_pad_id = -1;
|
||||||
|
vocab.special_cls_id = -1;
|
||||||
|
vocab.special_mask_id = -1;
|
||||||
vocab.linefeed_id = -1;
|
vocab.linefeed_id = -1;
|
||||||
|
|
||||||
return;
|
return;
|
||||||
@ -3908,6 +3918,8 @@ static void llm_load_vocab(
|
|||||||
vocab.special_unk_id = 0;
|
vocab.special_unk_id = 0;
|
||||||
vocab.special_sep_id = -1;
|
vocab.special_sep_id = -1;
|
||||||
vocab.special_pad_id = -1;
|
vocab.special_pad_id = -1;
|
||||||
|
vocab.special_cls_id = -1;
|
||||||
|
vocab.special_mask_id = -1;
|
||||||
|
|
||||||
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
|
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
|
||||||
if (add_space_prefix_keyidx != -1) {
|
if (add_space_prefix_keyidx != -1) {
|
||||||
@ -3947,15 +3959,19 @@ static void llm_load_vocab(
|
|||||||
vocab.special_unk_id = -1;
|
vocab.special_unk_id = -1;
|
||||||
vocab.special_sep_id = -1;
|
vocab.special_sep_id = -1;
|
||||||
vocab.special_pad_id = -1;
|
vocab.special_pad_id = -1;
|
||||||
|
vocab.special_cls_id = -1;
|
||||||
|
vocab.special_mask_id = -1;
|
||||||
} else if (tokenizer_name == "bert") {
|
} else if (tokenizer_name == "bert") {
|
||||||
vocab.type = LLAMA_VOCAB_TYPE_WPM;
|
vocab.type = LLAMA_VOCAB_TYPE_WPM;
|
||||||
|
|
||||||
// default special tokens
|
// default special tokens
|
||||||
vocab.special_bos_id = 101;
|
vocab.special_bos_id = -1;
|
||||||
vocab.special_eos_id = 102;
|
vocab.special_eos_id = -1;
|
||||||
vocab.special_unk_id = 100;
|
vocab.special_unk_id = 100;
|
||||||
vocab.special_sep_id = -1;
|
vocab.special_sep_id = 102;
|
||||||
vocab.special_pad_id = -1;
|
vocab.special_pad_id = 0;
|
||||||
|
vocab.special_cls_id = 101;
|
||||||
|
vocab.special_mask_id = 103;
|
||||||
vocab.add_space_prefix = false;
|
vocab.add_space_prefix = false;
|
||||||
} else {
|
} else {
|
||||||
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
|
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
|
||||||
@ -4023,6 +4039,8 @@ static void llm_load_vocab(
|
|||||||
{ LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
|
{ LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
|
||||||
{ LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
|
{ LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
|
||||||
{ LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
|
{ LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
|
||||||
|
{ LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
|
||||||
|
{ LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
|
||||||
};
|
};
|
||||||
for (const auto & it : special_token_types) {
|
for (const auto & it : special_token_types) {
|
||||||
const std::string & key = kv(std::get<0>(it));
|
const std::string & key = kv(std::get<0>(it));
|
||||||
@ -4219,6 +4237,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
|||||||
if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
|
if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
|
||||||
if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
|
if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
|
||||||
if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
|
if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
|
||||||
|
if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); }
|
||||||
|
if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
|
||||||
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
|
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -10995,9 +11015,6 @@ struct llm_tokenizer_wpm {
|
|||||||
output.push_back(vocab.special_unk_id);
|
output.push_back(vocab.special_unk_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// append eos token
|
|
||||||
output.push_back(vocab.special_eos_id);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> preprocess(const std::string & text) {
|
std::vector<std::string> preprocess(const std::string & text) {
|
||||||
@ -11202,30 +11219,28 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) {
|
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
|
||||||
std::vector<llama_vocab::id> output;
|
std::vector<llama_vocab::id> output;
|
||||||
|
|
||||||
// OG tokenizer behavior:
|
|
||||||
//
|
|
||||||
// tokenizer.encode('', add_bos=True) returns [1]
|
|
||||||
// tokenizer.encode('', add_bos=False) returns []
|
|
||||||
|
|
||||||
if (bos && vocab.special_bos_id != -1) {
|
|
||||||
output.push_back(vocab.special_bos_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (raw_text.empty()) {
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::forward_list<fragment_buffer_variant> fragment_buffer;
|
std::forward_list<fragment_buffer_variant> fragment_buffer;
|
||||||
fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
|
|
||||||
|
|
||||||
if (special) tokenizer_st_partition(vocab, fragment_buffer);
|
if (!raw_text.empty()) {
|
||||||
|
fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
|
||||||
|
if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
|
||||||
|
}
|
||||||
|
|
||||||
switch (vocab.type) {
|
switch (vocab.type) {
|
||||||
case LLAMA_VOCAB_TYPE_SPM:
|
case LLAMA_VOCAB_TYPE_SPM:
|
||||||
{
|
{
|
||||||
|
// OG tokenizer behavior:
|
||||||
|
//
|
||||||
|
// tokenizer.encode('', add_special_tokens=True) returns [1]
|
||||||
|
// tokenizer.encode('', add_special_tokens=False) returns []
|
||||||
|
|
||||||
|
if (add_special && vocab.special_add_bos == 1) {
|
||||||
|
GGML_ASSERT(vocab.special_bos_id != -1);
|
||||||
|
output.push_back(vocab.special_bos_id);
|
||||||
|
}
|
||||||
|
|
||||||
for (const auto & fragment : fragment_buffer) {
|
for (const auto & fragment : fragment_buffer) {
|
||||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
// without adding this leading whitespace, we do not get the same results as the original tokenizer
|
// without adding this leading whitespace, we do not get the same results as the original tokenizer
|
||||||
@ -11251,9 +11266,19 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
|||||||
output.push_back(fragment.token);
|
output.push_back(fragment.token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (add_special && vocab.special_add_eos == 1) {
|
||||||
|
GGML_ASSERT(vocab.special_eos_id != -1);
|
||||||
|
output.push_back(vocab.special_eos_id);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_BPE:
|
case LLAMA_VOCAB_TYPE_BPE:
|
||||||
{
|
{
|
||||||
|
if (add_special && vocab.special_add_bos == 1) {
|
||||||
|
GGML_ASSERT(vocab.special_bos_id != -1);
|
||||||
|
output.push_back(vocab.special_bos_id);
|
||||||
|
}
|
||||||
|
|
||||||
for (const auto & fragment : fragment_buffer) {
|
for (const auto & fragment : fragment_buffer) {
|
||||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||||
@ -11267,9 +11292,16 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
|||||||
output.push_back(fragment.token);
|
output.push_back(fragment.token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(vocab.special_add_eos != 1);
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_WPM:
|
case LLAMA_VOCAB_TYPE_WPM:
|
||||||
{
|
{
|
||||||
|
if (add_special) {
|
||||||
|
GGML_ASSERT(vocab.special_cls_id != -1);
|
||||||
|
output.push_back(vocab.special_cls_id);
|
||||||
|
}
|
||||||
|
|
||||||
for (const auto & fragment : fragment_buffer) {
|
for (const auto & fragment : fragment_buffer) {
|
||||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||||
@ -11283,6 +11315,11 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
|||||||
output.push_back(fragment.token);
|
output.push_back(fragment.token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (add_special) {
|
||||||
|
GGML_ASSERT(vocab.special_sep_id != -1);
|
||||||
|
output.push_back(vocab.special_sep_id);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_NONE:
|
case LLAMA_VOCAB_TYPE_NONE:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
@ -15272,6 +15309,14 @@ llama_token llama_token_eos(const struct llama_model * model) {
|
|||||||
return model->vocab.special_eos_id;
|
return model->vocab.special_eos_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_token llama_token_cls(const struct llama_model * model) {
|
||||||
|
return model->vocab.special_cls_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_token_sep(const struct llama_model * model) {
|
||||||
|
return model->vocab.special_sep_id;
|
||||||
|
}
|
||||||
|
|
||||||
llama_token llama_token_nl(const struct llama_model * model) {
|
llama_token llama_token_nl(const struct llama_model * model) {
|
||||||
return model->vocab.linefeed_id;
|
return model->vocab.linefeed_id;
|
||||||
}
|
}
|
||||||
@ -15306,9 +15351,9 @@ int32_t llama_tokenize(
|
|||||||
int32_t text_len,
|
int32_t text_len,
|
||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int32_t n_tokens_max,
|
int32_t n_tokens_max,
|
||||||
bool add_bos,
|
bool add_special,
|
||||||
bool special) {
|
bool parse_special) {
|
||||||
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
|
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
|
||||||
|
|
||||||
if (n_tokens_max < (int) res.size()) {
|
if (n_tokens_max < (int) res.size()) {
|
||||||
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
|
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
|
||||||
|
10
llama.h
10
llama.h
@ -721,6 +721,8 @@ extern "C" {
|
|||||||
// Special tokens
|
// Special tokens
|
||||||
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
|
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
|
||||||
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
||||||
|
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
|
||||||
|
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
|
||||||
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
||||||
|
|
||||||
// Returns -1 if unknown, 1 for true or 0 for false.
|
// Returns -1 if unknown, 1 for true or 0 for false.
|
||||||
@ -743,16 +745,16 @@ extern "C" {
|
|||||||
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
|
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
|
||||||
/// @return Returns the number of tokens on success, no more than n_tokens_max
|
/// @return Returns the number of tokens on success, no more than n_tokens_max
|
||||||
/// @return Returns a negative number on failure - the number of tokens that would have been returned
|
/// @return Returns a negative number on failure - the number of tokens that would have been returned
|
||||||
/// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
|
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
|
||||||
/// Does not insert a leading space.
|
/// as plaintext. Does not insert a leading space.
|
||||||
LLAMA_API int32_t llama_tokenize(
|
LLAMA_API int32_t llama_tokenize(
|
||||||
const struct llama_model * model,
|
const struct llama_model * model,
|
||||||
const char * text,
|
const char * text,
|
||||||
int32_t text_len,
|
int32_t text_len,
|
||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int32_t n_tokens_max,
|
int32_t n_tokens_max,
|
||||||
bool add_bos,
|
bool add_special,
|
||||||
bool special);
|
bool parse_special);
|
||||||
|
|
||||||
// Token Id -> Piece.
|
// Token Id -> Piece.
|
||||||
// Uses the vocabulary in the provided context.
|
// Uses the vocabulary in the provided context.
|
||||||
|
Loading…
Reference in New Issue
Block a user