From 6c63550f63657cd00c5e6166c533ff190d1e0f3a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 14 Aug 2023 22:10:19 +0300 Subject: [PATCH] llama : update tokenizer style --- gguf-llama.cpp | 87 +++++++++++++++++++++++++++--------------------- llama.cpp | 89 +++++++++++++++++++++++++++++--------------------- 2 files changed, 101 insertions(+), 75 deletions(-) diff --git a/gguf-llama.cpp b/gguf-llama.cpp index 8275c7357..1c1d6718e 100644 --- a/gguf-llama.cpp +++ b/gguf-llama.cpp @@ -2112,49 +2112,56 @@ static bool llama_eval_internal( // tokenizer // -static std::string llama_vocab_type(const llama_vocab& vocab) { +static std::string llama_vocab_type(const llama_vocab & vocab) { return vocab.token_to_id.size() == 32000 ? "spm": "bpe"; } -static bool llama_is_normal_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return token >= 259; - else if(llama_vocab_type(vocab) == "bpe") + } + + if (llama_vocab_type(vocab) == "bpe") { return token >= 95; - else - return false; + } + + return false; } -static bool llama_is_unknown_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return token == 0; - else - // TODO: improve? - return false; + } + + // TODO: improve? + return false; } -static bool llama_is_control_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return token == 1 || token == 2; - else - // TODO: improve? - return false; + } + + // TODO: improve? + return false; } -static bool llama_is_bos_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return token == 1; - else - // TODO: improve? - return false; + } + + // TODO: improve? + return false; } -static bool llama_is_eos_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return token == 2; - else - // TODO: improve? - return false; + } + + // TODO: improve? + return false; } static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) { @@ -2164,29 +2171,35 @@ static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token t return false; } -static bool llama_is_unused_token(const llama_vocab& vocab, llama_token token) { +static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token) { UNUSED(vocab); UNUSED(token); // TODO: improve? return false; } -static bool llama_is_byte_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return 3 <= token && token < 259; - else if(llama_vocab_type(vocab) == "bpe") + } + + if (llama_vocab_type(vocab) == "bpe") { return 1 <= token && token < 95; - else - return false; + } + + return false; } -static uint8_t llama_byte_to_char(const llama_vocab& vocab, uint8_t byte) { - if(llama_vocab_type(vocab) == "spm") +static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) { + if (llama_vocab_type(vocab) == "spm") { return byte + 3; - else if(llama_vocab_type(vocab) == "bpe") + } + + if (llama_vocab_type(vocab) == "bpe") { return byte + 32; - else - return false; + } + + return false; } static std::string llama_escape_whitespace(const std::string& text) { diff --git a/llama.cpp b/llama.cpp index 2ac907b96..fb2207f84 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1944,81 +1944,94 @@ static bool llama_eval_internal( // tokenizer // -static std::string llama_vocab_type(const llama_vocab& vocab) { +static std::string llama_vocab_type(const llama_vocab & vocab) { return vocab.token_to_id.size() == 32000 ? "spm": "bpe"; } -static bool llama_is_normal_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return token >= 259; - else if(llama_vocab_type(vocab) == "bpe") + } + + if (llama_vocab_type(vocab) == "bpe") { return token >= 95; - else - return false; + } + + return false; } -static bool llama_is_unknown_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return token == 0; - else - // TODO: improve? - return false; + } + + // TODO: improve? + return false; } -static bool llama_is_control_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return token == 1 || token == 2; - else - // TODO: improve? - return false; + } + + // TODO: improve? + return false; } -static bool llama_is_bos_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return token == 1; - else - // TODO: improve? - return false; + } + + // TODO: improve? + return false; } -static bool llama_is_eos_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return token == 2; - else - // TODO: improve? - return false; + } + + // TODO: improve? + return false; } -static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token token) { +static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) { UNUSED(vocab); UNUSED(token); // TODO: improve? return false; } -static bool llama_is_unused_token(const llama_vocab& vocab, llama_token token) { +static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token) { UNUSED(vocab); UNUSED(token); // TODO: improve? return false; } -static bool llama_is_byte_token(const llama_vocab& vocab, llama_token token) { - if(llama_vocab_type(vocab) == "spm") +static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) { + if (llama_vocab_type(vocab) == "spm") { return 3 <= token && token < 259; - else if(llama_vocab_type(vocab) == "bpe") + } + + if (llama_vocab_type(vocab) == "bpe") { return 1 <= token && token < 95; - else - return false; + } + + return false; } -static uint8_t llama_byte_to_char(const llama_vocab& vocab, uint8_t byte) { - if(llama_vocab_type(vocab) == "spm") +static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) { + if (llama_vocab_type(vocab) == "spm") { return byte + 3; - else if(llama_vocab_type(vocab) == "bpe") + } + + if (llama_vocab_type(vocab) == "bpe") { return byte + 32; - else - return false; + } + + return false; } static std::string llama_escape_whitespace(const std::string& text) {