diff --git a/llama.cpp b/llama.cpp index e7412de4b..2493802ab 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1702,12 +1702,13 @@ struct llama_mlock { }; using llama_mlocks = std::vector>; -static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { +// NOTE: avoid ever using this except for building the token_to_piece caches +static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); + const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); + int check = llama_token_to_piece(model, token, result.data(), result.size(), special); GGML_ASSERT(check == -n_tokens); } else { @@ -2162,7 +2163,11 @@ struct llama_vocab { std::unordered_map token_to_id; std::vector id_to_token; - std::vector special_tokens_cache; + bool has_cache = false; + + std::vector cache_special_tokens; + std::unordered_map cache_token_to_piece; // llama_token_to_piece(special = false); + std::unordered_map cache_token_to_piece_special; // llama_token_to_piece(special = true); std::map, int> bpe_ranks; @@ -4833,18 +4838,26 @@ static void llm_load_vocab( { for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) { - vocab.special_tokens_cache.push_back(id); + vocab.cache_special_tokens.push_back(id); } } - std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(), + std::sort( vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(), [&] (const llama_vocab::id a, const llama_vocab::id b) { return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size(); } ); - LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size()); + LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.cache_special_tokens.size()); } + + // build token to piece caches + for (llama_token id = 0; id < (llama_token) n_vocab; ++id) { + vocab.cache_token_to_piece[id] = llama_token_to_piece(&model, id, false); + vocab.cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true); + } + + vocab.has_cache = true; } static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { @@ -13233,7 +13246,7 @@ struct fragment_buffer_variant { static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) { // for each special token - for (const llama_vocab::id special_id : vocab.special_tokens_cache) { + for (const llama_vocab::id special_id : vocab.cache_special_tokens) { const auto & special_token = vocab.id_to_token[special_id].text; // for each text fragment @@ -14392,7 +14405,7 @@ void llama_sample_repetition_penalties( void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { GGML_ASSERT(ctx); - const int64_t t_start_sample_us = ggml_time_us(); + int64_t t_start_sample_us = ggml_time_us(); bool allow_eog = false; for (const auto & stack : grammar->stacks) { @@ -14408,8 +14421,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c candidates_grammar.reserve(candidates->size); for (size_t i = 0; i < candidates->size; ++i) { - const llama_token id = candidates->data[i].id; - const std::string piece = llama_token_to_piece(ctx, id, false); + const llama_token id = candidates->data[i].id; + const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id); if (llama_token_is_eog(&ctx->model, id)) { if (!allow_eog) { @@ -14609,7 +14622,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar GGML_ASSERT(false); } - const std::string piece = llama_token_to_piece(ctx, token, false); + const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token); // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar->partial_utf8); @@ -18292,69 +18305,79 @@ static std::string llama_decode_text(const std::string & text) { // does not write null-terminator to buf int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) { + if (model->vocab.has_cache) { + const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece; + const auto & res = cache.at(token); + if (length < (int) res.size()) { + return -(int) res.size(); + } + memcpy(buf, res.c_str(), res.size()); + return res.size(); + } + if (0 <= token && token < llama_n_vocab(model)) { switch (llama_vocab_get_type(model->vocab)) { - case LLAMA_VOCAB_TYPE_WPM: - case LLAMA_VOCAB_TYPE_SPM: { - // NOTE: we accept all unsupported token types, - // suppressing them like CONTROL tokens. - if (llama_is_normal_token(model->vocab, token)) { - std::string result = model->vocab.id_to_token[token].text; - llama_unescape_whitespace(result); - if (length < (int) result.length()) { - return -(int) result.length(); + case LLAMA_VOCAB_TYPE_WPM: + case LLAMA_VOCAB_TYPE_SPM: { + // NOTE: we accept all unsupported token types, + // suppressing them like CONTROL tokens. + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].text; + llama_unescape_whitespace(result); + if (length < (int) result.length()) { + return -(int) result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if ( + (llama_is_user_defined_token(model->vocab, token)) || + (llama_is_control_token (model->vocab, token) && special)) { + std::string result = model->vocab.id_to_token[token].text; + if (length < (int) result.length()) { + return -(int) result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT + if (length < 3) { + return -3; + } + memcpy(buf, "\xe2\x96\x85", 3); + return 3; + } else if (llama_is_byte_token(model->vocab, token)) { + if (length < 1) { + return -1; + } + buf[0] = llama_token_to_byte(model->vocab, token); + return 1; } - memcpy(buf, result.c_str(), result.length()); - return result.length(); - } else if ( - (llama_is_user_defined_token(model->vocab, token)) || - (llama_is_control_token (model->vocab, token) && special)) { - std::string result = model->vocab.id_to_token[token].text; - if (length < (int) result.length()) { - return -(int) result.length(); - } - memcpy(buf, result.c_str(), result.length()); - return result.length(); - } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT - if (length < 3) { - return -3; - } - memcpy(buf, "\xe2\x96\x85", 3); - return 3; - } else if (llama_is_byte_token(model->vocab, token)) { - if (length < 1) { - return -1; - } - buf[0] = llama_token_to_byte(model->vocab, token); - return 1; + break; } - break; - } - case LLAMA_VOCAB_TYPE_BPE: { - // NOTE: we accept all unsupported token types, - // suppressing them like CONTROL tokens. - if (llama_is_normal_token(model->vocab, token)) { - std::string result = model->vocab.id_to_token[token].text; - result = llama_decode_text(result); - if (length < (int) result.length()) { - return -(int) result.length(); + case LLAMA_VOCAB_TYPE_BPE: { + // NOTE: we accept all unsupported token types, + // suppressing them like CONTROL tokens. + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].text; + result = llama_decode_text(result); + if (length < (int) result.length()) { + return -(int) result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if ( + (llama_is_user_defined_token(model->vocab, token)) || + (llama_is_control_token (model->vocab, token) && special)) { + std::string result = model->vocab.id_to_token[token].text; + if (length < (int) result.length()) { + return -(int) result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); } - memcpy(buf, result.c_str(), result.length()); - return result.length(); - } else if ( - (llama_is_user_defined_token(model->vocab, token)) || - (llama_is_control_token (model->vocab, token) && special)) { - std::string result = model->vocab.id_to_token[token].text; - if (length < (int) result.length()) { - return -(int) result.length(); - } - memcpy(buf, result.c_str(), result.length()); - return result.length(); + break; } - break; - } - default: - GGML_ASSERT(false); + default: + GGML_ASSERT(false); } } return 0; diff --git a/llama.h b/llama.h index 3e4474bb9..95105c28e 100644 --- a/llama.h +++ b/llama.h @@ -424,8 +424,8 @@ extern "C" { LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);