From 7494c7842812f58c65b68f82cb3aafacc645dc64 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 14 Aug 2023 21:33:33 +0300 Subject: [PATCH] llama : sync gguf-llama with llama (#2613) * llama : sync gguf-llama with llama * tests : fix build + warnings (test-tokenizer-1 still fails) * tests : fix wstring_convert * convert : fix layer names * llama : sync gguf-llama.cpp * convert : update HF converter to new tokenizer voodoo magics --- convert-llama-h5-to-gguf.py | 20 +- examples/gguf/gguf-llama-simple.cpp | 252 ++++++++++---------- gguf-llama.cpp | 350 ++++++++++++++++++++++++---- gguf-llama.h | 60 ++++- gguf_namemap.py | 86 +++---- llama.cpp | 59 +++-- tests/test-tokenizer-0.cpp | 30 +-- tests/test-tokenizer-1.cpp | 26 ++- 8 files changed, 590 insertions(+), 293 deletions(-) diff --git a/convert-llama-h5-to-gguf.py b/convert-llama-h5-to-gguf.py index 0bce659e6..9d91b433b 100644 --- a/convert-llama-h5-to-gguf.py +++ b/convert-llama-h5-to-gguf.py @@ -95,7 +95,7 @@ else: gguf_writer.add_architecture(llm_arch) gguf_writer.add_name(last_dir) -gguf_writer.add_file_type( "All tensors F32" if ftype == 0 else "Most tensors F16, some F32") +gguf_writer.add_file_type("All tensors F32" if ftype == 0 else "Most tensors F16, some F32") gguf_writer.add_source_hf_repo(hf_repo) gguf_writer.add_context_length(llm_arch, hparams["max_position_embeddings"]) gguf_writer.add_embedding_length(llm_arch, hparams["hidden_size"]) @@ -122,19 +122,11 @@ if Path(dir_model + "/tokenizer.model").is_file(): for i in range(tokenizer.vocab_size()): text: bytes - if tokenizer.is_unknown(i): - text = " \u2047 ".encode("utf-8") - elif tokenizer.is_control(i): - text = b"" - if tokenizer.is_byte(i): - piece = tokenizer.id_to_piece(i) - if len(piece) != 6: - raise Exception(f"Invalid token: {piece}") - byte_value = int(piece[3:-1], 16) - text = struct.pack("B", byte_value) - else: - text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") - score: float = tokenizer.get_score(i) + score: float + + piece = tokenizer.id_to_piece(i) + text = piece.encode("utf-8") + score = tokenizer.get_score(i) tokens.append(text) scores.append(score) diff --git a/examples/gguf/gguf-llama-simple.cpp b/examples/gguf/gguf-llama-simple.cpp index fa8fc6fef..0679240d3 100644 --- a/examples/gguf/gguf-llama-simple.cpp +++ b/examples/gguf/gguf-llama-simple.cpp @@ -1,126 +1,126 @@ -#ifndef _GNU_SOURCE -#define _GNU_SOURCE -#endif - -#include "common.h" -#include "gguf-llama.h" -#include "build-info.h" - -#include -#include -#include -#include - -int main(int argc, char ** argv) { - gpt_params params; - - if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]); - return 1 ; - } - - if (argc >= 2) { - params.model = argv[1]; - } - - if (argc >= 3) { - params.prompt = argv[2]; - } - - if (params.prompt.empty()) { - params.prompt = "Hello my name is"; - } - - // init LLM - - llama_backend_init(params.numa); - - llama_context_params ctx_params = llama_context_default_params(); - - llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); - - if (model == NULL) { - fprintf(stderr , "%s: error: unable to load model\n" , __func__); - return 1; - } - - llama_context * ctx = llama_new_context_with_model(model, ctx_params); - - // tokenize the prompt - - std::vector tokens_list; - tokens_list = ::llama_tokenize(ctx, params.prompt, true); - - const int max_context_size = llama_n_ctx(ctx); - const int max_tokens_list_size = max_context_size - 4; - - if ((int)tokens_list.size() > max_tokens_list_size) { - fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size); - return 1; - } - - fprintf(stderr, "\n\n"); - - for (auto id : tokens_list) { - fprintf(stderr, "%s", llama_token_to_str(ctx, id)); - } - - fflush(stderr); - - // main loop - - // The LLM keeps a contextual cache memory of previous token evaluation. - // Usually, once this cache is full, it is required to recompute a compressed context based on previous - // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist - // example, we will just stop the loop once this cache is full or once an end of stream is detected. - - while (llama_get_kv_cache_token_count(ctx) < max_context_size) { - // evaluate the transformer - - if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; - } - - tokens_list.clear(); - - // sample the next token - - llama_token new_token_id = 0; - - auto logits = llama_get_logits(ctx); - auto n_vocab = llama_n_vocab(ctx); - - std::vector candidates; - candidates.reserve(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - - new_token_id = llama_sample_token_greedy(ctx , &candidates_p); - - // is it an end of stream ? - if (new_token_id == llama_token_eos()) { - fprintf(stderr, " [end of text]\n"); - break; - } - - // print the new token : - printf("%s", llama_token_to_str(ctx, new_token_id)); - fflush(stdout); - - // push this new token for next evaluation - tokens_list.push_back(new_token_id); - - } - - llama_free(ctx); - llama_free_model(model); - - llama_backend_free(); - - return 0; -} +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include "common.h" +#include "gguf-llama.h" +#include "build-info.h" + +#include +#include +#include +#include + +int main(int argc, char ** argv) { + gpt_params params; + + if (argc == 1 || argv[1][0] == '-') { + printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]); + return 1 ; + } + + if (argc >= 2) { + params.model = argv[1]; + } + + if (argc >= 3) { + params.prompt = argv[2]; + } + + if (params.prompt.empty()) { + params.prompt = "Hello my name is"; + } + + // init LLM + + llama_backend_init(params.numa); + + llama_context_params ctx_params = llama_context_default_params(); + + llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); + + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + llama_context * ctx = llama_new_context_with_model(model, ctx_params); + + // tokenize the prompt + + std::vector tokens_list; + tokens_list = ::llama_tokenize(ctx, params.prompt, true); + + const int max_context_size = llama_n_ctx(ctx); + const int max_tokens_list_size = max_context_size - 4; + + if ((int) tokens_list.size() > max_tokens_list_size) { + fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size); + return 1; + } + + fprintf(stderr, "\n\n"); + + for (auto id : tokens_list) { + fprintf(stderr, "%s", llama_token_to_str(ctx, id).c_str()); + } + + fflush(stderr); + + // main loop + + // The LLM keeps a contextual cache memory of previous token evaluation. + // Usually, once this cache is full, it is required to recompute a compressed context based on previous + // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist + // example, we will just stop the loop once this cache is full or once an end of stream is detected. + + while (llama_get_kv_cache_token_count(ctx) < max_context_size) { + // evaluate the transformer + + if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return 1; + } + + tokens_list.clear(); + + // sample the next token + + llama_token new_token_id = 0; + + auto logits = llama_get_logits(ctx); + auto n_vocab = llama_n_vocab(ctx); + + std::vector candidates; + candidates.reserve(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + new_token_id = llama_sample_token_greedy(ctx , &candidates_p); + + // is it an end of stream ? + if (new_token_id == llama_token_eos()) { + fprintf(stderr, " [end of text]\n"); + break; + } + + // print the new token : + printf("%s", llama_token_to_str(ctx, new_token_id).c_str()); + fflush(stdout); + + // push this new token for next evaluation + tokens_list.push_back(new_token_id); + + } + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + return 0; +} diff --git a/gguf-llama.cpp b/gguf-llama.cpp index 0f8eb3c90..8275c7357 100644 --- a/gguf-llama.cpp +++ b/gguf-llama.cpp @@ -7,6 +7,7 @@ #endif #include "gguf-util.h" +#define LLAMA_API_CPP // TODO: eliminate me #include "gguf-llama.h" #include "ggml.h" @@ -76,7 +77,7 @@ static std::string to_string(const T & val) { #define LLAMA_MAX_SCRATCH_BUFFERS 16 #endif -typedef void (*offload_func_t)(struct ggml_tensor * tensor); +#define UNUSED GGML_UNUSED #ifdef GGML_USE_CUBLAS #define llama_host_malloc(n) ggml_cuda_host_malloc(n) @@ -125,6 +126,8 @@ struct llama_buffer { } }; +typedef void (*offload_func_t)(struct ggml_tensor * tensor); + void llama_nop(struct ggml_tensor * tensor) { // don't offload by default (void) tensor; } @@ -623,7 +626,7 @@ struct gguf_file_loader { hparams.n_embd = read_u32("llama.embedding_length"); hparams.n_ff = read_u32("llama.feed_forward_length"); hparams.n_head = read_u32("llama.attention.head_count"); - hparams.n_layer = read_u32("llama.layer_count"); + hparams.n_layer = read_u32("llama.block_count"); hparams.n_rot = read_u32("llama.rope.dimension_count"); hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon"); @@ -1081,7 +1084,7 @@ static bool kv_cache_init( cache.ctx = ggml_init(params); if (!cache.ctx) { - fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + LLAMA_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__); return false; } @@ -1370,7 +1373,7 @@ static void llama_model_load_internal( ml->ggml_ctx = ctx; - model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_embeddings = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU); // "output" tensor { @@ -1391,8 +1394,8 @@ static void llama_model_load_internal( backend_output = GGML_BACKEND_CPU; } - model.norm = ml->get_tensor("norm.weight", {n_embd}, backend_norm); - model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); + model.norm = ml->get_tensor("output_norm.weight", {n_embd}, backend_norm); + model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); if (backend_norm == GGML_BACKEND_GPU) { vram_weights += ggml_nbytes(model.norm); } @@ -1410,20 +1413,20 @@ static void llama_model_load_internal( auto & layer = model.layers[i]; - std::string layers_i = "layers." + std::to_string(i); + std::string layers_i = "blk." + std::to_string(i); - layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend); + layer.attention_norm = ml->get_tensor(layers_i + ".attn_norm.weight", {n_embd}, backend); - layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split); - layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd_gqa}, backend_split); - layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split); - layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split); + layer.wq = ml->get_tensor(layers_i + ".attn_q.weight", {n_embd, n_embd}, backend_split); + layer.wk = ml->get_tensor(layers_i + ".attn_k.weight", {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml->get_tensor(layers_i + ".attn_v.weight", {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend_split); layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); - layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); - layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); - layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); + layer.w1 = ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, n_ff}, backend_split); + layer.w2 = ml->get_tensor(layers_i + ".ffn_down.weight", { n_ff, n_embd}, backend_split); + layer.w3 = ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -2109,6 +2112,109 @@ static bool llama_eval_internal( // tokenizer // +static std::string llama_vocab_type(const llama_vocab& vocab) { + return vocab.token_to_id.size() == 32000 ? "spm": "bpe"; +} + +static bool llama_is_normal_token(const llama_vocab& vocab, llama_token token) { + if(llama_vocab_type(vocab) == "spm") + return token >= 259; + else if(llama_vocab_type(vocab) == "bpe") + return token >= 95; + else + return false; +} + +static bool llama_is_unknown_token(const llama_vocab& vocab, llama_token token) { + if(llama_vocab_type(vocab) == "spm") + return token == 0; + else + // TODO: improve? + return false; +} + +static bool llama_is_control_token(const llama_vocab& vocab, llama_token token) { + if(llama_vocab_type(vocab) == "spm") + return token == 1 || token == 2; + else + // TODO: improve? + return false; +} + +static bool llama_is_bos_token(const llama_vocab& vocab, llama_token token) { + if(llama_vocab_type(vocab) == "spm") + return token == 1; + else + // TODO: improve? + return false; +} + +static bool llama_is_eos_token(const llama_vocab& vocab, llama_token token) { + if(llama_vocab_type(vocab) == "spm") + return token == 2; + else + // TODO: improve? + return false; +} + +static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) { + UNUSED(vocab); + UNUSED(token); + // TODO: improve? + return false; +} + +static bool llama_is_unused_token(const llama_vocab& vocab, llama_token token) { + UNUSED(vocab); + UNUSED(token); + // TODO: improve? + return false; +} + +static bool llama_is_byte_token(const llama_vocab& vocab, llama_token token) { + if(llama_vocab_type(vocab) == "spm") + return 3 <= token && token < 259; + else if(llama_vocab_type(vocab) == "bpe") + return 1 <= token && token < 95; + else + return false; +} + +static uint8_t llama_byte_to_char(const llama_vocab& vocab, uint8_t byte) { + if(llama_vocab_type(vocab) == "spm") + return byte + 3; + else if(llama_vocab_type(vocab) == "bpe") + return byte + 32; + else + return false; +} + +static std::string llama_escape_whitespace(const std::string& text) { + std::string result; + bool escaping = false; + result += "\xe2\x96\x81"; + for (size_t offs = 0; offs < text.length(); ++offs) { + if (text[offs] == ' ') { + if (!escaping) { + result += "\xe2\x96\x81"; + escaping = true; + } + } + else { + escaping = false; + result += text[offs]; + } + } + return result; +} + +static std::string llama_unescape_whitespace(const std::string& word) { + if (word.length() >= 3 && word.substr(0, 3) == "\xe2\x96\x81") { + return std::string(" ") + word.substr(3); + } + return word; +} + static size_t utf8_len(char src) { const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; uint8_t highbits = static_cast(src) >> 4; @@ -2150,10 +2256,11 @@ struct llama_tokenizer { size_t offs = 0; while (offs < text.size()) { llama_sp_symbol sym; - size_t char_len = std::min(text.size() - offs, utf8_len(text[offs])); + size_t len = utf8_len(text[offs]); + GGML_ASSERT(offs + len <= text.size()); sym.text = text.c_str() + offs; - sym.n = char_len; - offs += char_len; + sym.n = len; + offs += len; sym.prev = index - 1; sym.next = offs == text.size() ? -1 : index + 1; index++; @@ -2198,23 +2305,36 @@ struct llama_tokenizer { for (int i = 0; i != -1; i = symbols_[i].next) { auto & symbol = symbols_[i]; - auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n)); - - if (token == vocab_.token_to_id.end()) { - // output any symbols that did not form tokens as bytes. - for (int j = 0; j < (int) symbol.n; ++j) { - // NOTE: old version, before #2420 - not sure what are the implications of this - //llama_vocab::id token_id = static_cast(symbol.text[j]) + 3; - llama_vocab::id token_id = vocab_.token_to_id.at(std::string(1, symbol.text[j])); - output.push_back(token_id); - } - } else { - output.push_back((*token).second); - } + resegment(symbol, output); } } private: + void resegment(llama_sp_symbol &symbol, std::vector &output) { + auto text = std::string(symbol.text, symbol.n); + auto token = vocab_.token_to_id.find(text); + + // Do we need to support is_unused? + if (token != vocab_.token_to_id.end()) { + output.push_back((*token).second); + return; + } + + const auto p = rev_merge.find(text); + + if (p == rev_merge.end()) { + // output any symbols that did not form tokens as bytes. + for (int j = 0; j < (int)symbol.n; ++j) { + llama_vocab::id token_id = llama_byte_to_char(vocab_, symbol.text[j]); + output.push_back(token_id); + } + return; + } + + resegment(symbols_[p->second.first], output); + resegment(symbols_[p->second.second], output); + } + void try_add_bigram(int left, int right) { if (left == -1 || right == -1) { return; @@ -2239,18 +2359,22 @@ private: bigram.score = tok_score.score; bigram.size = text.size(); work_queue_.push(bigram); + + // Do we need to support is_unused? + rev_merge[text] = std::make_pair(left, right); } const llama_vocab & vocab_; std::vector symbols_; llama_sp_bigram::queue work_queue_; + std::map > rev_merge; }; -static std::vector llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) { +static std::vector llama_tokenize(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) { llama_tokenizer tokenizer(vocab); std::vector output; - if (text.empty()) { + if (raw_text.empty()) { return output; } @@ -2258,6 +2382,13 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co output.push_back(llama_token_bos()); } + std::string text; + if (escape) { + text = llama_escape_whitespace(raw_text); + } else { + text = raw_text; + } + tokenizer.tokenize(text, output); return output; } @@ -2839,15 +2970,15 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const char * str = llama_token_to_str(ctx, id); + std::string str = llama_token_to_str(ctx, id); if (id == eos) { if (!allow_eos) { candidates->data[i].logit = -INFINITY; } - } else if (*str == 0) { + } else if (str.empty()) { candidates->data[i].logit = -INFINITY; } else { - candidates_decoded.push_back(decode_utf8(str)); + candidates_decoded.push_back(decode_utf8(str.c_str())); candidates_grammar.push_back({ i, candidates_decoded.back().data() }); } } @@ -3048,9 +3179,9 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar GGML_ASSERT(false); } - const char * str = llama_token_to_str(ctx, token); + std::string str = llama_token_to_str(ctx, token); // Note terminating 0 in decoded string - auto code_points = decode_utf8(str); + auto code_points = decode_utf8(str.c_str()); for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); } @@ -3507,6 +3638,12 @@ struct llama_context * llama_new_context_with_model( // this allocates all Metal resources and memory buffers ctx->ctx_metal = ggml_metal_init(1); + if (!ctx->ctx_metal) { + LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); + llama_free(ctx); + return NULL; + } + void * data_ptr = NULL; size_t data_size = 0; @@ -4281,7 +4418,8 @@ int llama_tokenize_with_model( llama_token * tokens, int n_max_tokens, bool add_bos) { - auto res = llama_tokenize(model->vocab, text, add_bos); + auto escape = llama_vocab_type(model->vocab) == "spm"; + auto res = llama_tokenize(model->vocab, text, add_bos, escape); if (n_max_tokens < (int) res.size()) { LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); @@ -4304,6 +4442,62 @@ int llama_tokenize( return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos); } +std::vector llama_tokenize( + struct llama_context * ctx, + const std::string & text, + bool add_bos) { + int length = text.length() + add_bos; + std::vector result(length); + length = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + if (length < 0) { + result.resize(-length); + int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + assert(check == -length); + GGML_UNUSED(check); + } else { + result.resize(length); + } + return result; +} + +int llama_tokenize_bpe( + struct llama_context * ctx, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos) { + auto res = llama_tokenize(ctx->model.vocab, text, add_bos, false); + + if (n_max_tokens < (int) res.size()) { + LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); + return -((int) res.size()); + } + + for (size_t i = 0; i < res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +std::vector llama_tokenize_bpe( + struct llama_context * ctx, + const std::string & text, + bool add_bos) { + int length = text.length() + add_bos; + std::vector result(length); + length = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos); + if (length < 0) { + result.resize(-length); + int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos); + assert(check == -length); + GGML_UNUSED(check); + } else { + result.resize(length); + } + return result; +} + int llama_n_vocab_from_model(const struct llama_model * model) { return model->vocab.id_to_token.size(); } @@ -4357,16 +4551,80 @@ float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embedding.data(); } -const char * llama_token_to_str_with_model(const struct llama_model * model, llama_token token) { - if (token >= llama_n_vocab_from_model(model)) { - return nullptr; +int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) { + if (0 <= token && token < llama_n_vocab_from_model(model)) { + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].tok; + if(llama_vocab_type(model->vocab) == "spm") { + result = llama_unescape_whitespace(result); + } + if (length < (int) result.length()) { + return -result.length(); + } + strncpy(str, result.c_str(), result.length()); + return result.length(); + } else if (llama_is_unknown_token(model->vocab, token)) { + if (length < 3) { + return -3; + } + strncpy(str, "\xe2\x96\x85", 3); + return 3; + } else if (llama_is_control_token(model->vocab, token)) { + ; + } else if (llama_is_byte_token(model->vocab, token)) { + if (length < 1) { + return -1; + } + str[0] = llama_byte_to_char(model->vocab, token); + str[1] = 0x00; + return 1; + } } - - return model->vocab.id_to_token[token].tok.c_str(); + return 0; } -const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) { - return llama_token_to_str_with_model(&ctx->model, token); +int llama_token_to_str(const struct llama_context * ctx, llama_token token, char * str, int length) { + return llama_token_to_str_with_model(&ctx->model, token, str, length); +} + +std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int length = llama_token_to_str(ctx, token, result.data(), result.size()); + if (length < 0) { + result.resize(-length); + int check = llama_token_to_str(ctx, token, result.data(), result.size()); + GGML_ASSERT(check == -length); + } else { + result.resize(length); + } + + return std::string(result.data(), result.size()); +} + +int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * str, int length) { + if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) { + std::string result = ctx->model.vocab.id_to_token[token].tok; + if (length < (int) result.length()) { + return -result.length(); + } + strncpy(str, result.c_str(), result.length()); + return result.length(); + } + return 0; +} + +std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int length = llama_token_to_str_bpe(ctx, token, result.data(), result.size()); + if (length < 0) { + result.resize(-length); + const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size()); + GGML_ASSERT(check == -length); + } else { + result.resize(length); + } + + return std::string(result.data(), result.size()); } llama_token llama_token_bos() { diff --git a/gguf-llama.h b/gguf-llama.h index a8ed69d91..f342a534c 100644 --- a/gguf-llama.h +++ b/gguf-llama.h @@ -311,6 +311,13 @@ extern "C" { int n_max_tokens, bool add_bos); + LLAMA_API int llama_tokenize_bpe( + struct llama_context * ctx, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos); + LLAMA_API int llama_tokenize_with_model( const struct llama_model * model, const char * text, @@ -352,14 +359,23 @@ extern "C" { LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); // Token Id -> String. Uses the vocabulary in the provided context - LLAMA_API const char * llama_token_to_str( + LLAMA_API int llama_token_to_str( const struct llama_context * ctx, - llama_token token); + llama_token token, + char * str, + int length); - LLAMA_API const char * llama_token_to_str_with_model( + LLAMA_API int llama_token_to_str_bpe( + const struct llama_context * ctx, + llama_token token, + char * str, + int length); + + LLAMA_API int llama_token_to_str_with_model( const struct llama_model * model, - llama_token token); - + llama_token token, + char * str, + int length); // Special tokens LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence LLAMA_API llama_token llama_token_eos(); // end-of-sentence @@ -447,15 +463,43 @@ extern "C" { } #endif -// Internal API to be implemented by llama.cpp and used by tests/benchmarks only -#ifdef LLAMA_API_INTERNAL +// C++ API, will be moving to common.h soon (TM) +#ifdef LLAMA_API_CPP #include #include + +// +// Vocab utils +// + +std::vector llama_tokenize( + struct llama_context * ctx, + const std::string & text, + bool add_bos); + +std::vector llama_tokenize_bpe( + struct llama_context * ctx, + const std::string & text, + bool add_bos); + +std::string llama_token_to_str( + const struct llama_context * ctx, + llama_token token); + +std::string llama_token_to_str_bpe( + const struct llama_context * ctx, + llama_token token); + +// Internal API to be implemented by llama.cpp and used by tests/benchmarks only +#ifdef LLAMA_API_INTERNAL + struct ggml_tensor; const std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx); -#endif +#endif // LLAMA_API_CPP + +#endif // LLAMA_API_INTERNAL #endif // LLAMA_H diff --git a/gguf_namemap.py b/gguf_namemap.py index 7546630ed..06cd0132d 100644 --- a/gguf_namemap.py +++ b/gguf_namemap.py @@ -4,92 +4,92 @@ def get_tensor_namemap( n_blocks : int): tensor_map = {} # Token embeddings mapped_to = "token_embd" - tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox - tensor_map["transformer.wte"] = mapped_to # gpt2 mpt + tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox + tensor_map["transformer.wte"] = mapped_to # gpt2 mpt tensor_map["transformer.word_embeddings"] = mapped_to # falcon - tensor_map["model.embed_tokens"] = mapped_to # llama-hf - tensor_map["tok_embeddings"] = mapped_to # llama-pth + tensor_map["model.embed_tokens"] = mapped_to # llama-hf + tensor_map["tok_embeddings"] = mapped_to # llama-pth # Position embeddings mapped_to = "pos_embd" tensor_map["transformer.wpe"] = mapped_to # gpt2 # Output norm mapped_to = "output_norm" tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox - tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon - tensor_map["transformer.norm_f"] = mapped_to # mpt - tensor_map["model.norm"] = mapped_to # llama-hf - tensor_map["norm"] = mapped_to # llama-pth + tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon + tensor_map["transformer.norm_f"] = mapped_to # mpt + tensor_map["model.norm"] = mapped_to # llama-hf + tensor_map["norm"] = mapped_to # llama-pth # Output mapped_to = "output" tensor_map["embed_out"] = mapped_to # gptneox - tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf - tensor_map["output"] = mapped_to # llama-pth + tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf + tensor_map["output"] = mapped_to # llama-pth # Attention and fee-forward layer blocks for i in range(0,n_blocks): # Attention norm mapped_to = "blk."+str(i)+".attn_norm" tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt - tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b - tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b - tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth + tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b + tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b + tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth # Attention norm 2 mapped_to = "blk."+str(i)+".attn_norm_2" tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b # Attention query-key-value mapped_to = "blk."+str(i)+".attn_qkv" - tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt - tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon + tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon # Attention query mapped_to = "blk."+str(i)+".attn_q" tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth + tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth # Attention key mapped_to = "blk."+str(i)+".attn_k" tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth + tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth # Attention value mapped_to = "blk."+str(i)+".attn_v" tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth + tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth # Attention output mapped_to = "blk."+str(i)+".attn_output" - tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt + tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon - tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth + tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth # Feed-forward norm mapped_to = "blk."+str(i)+".ffn_norm" tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt - tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth + tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt + tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth # Feed-forward up mapped_to = "blk."+str(i)+".ffn_up" tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt - tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon - tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth + tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth # Feed-forward gate mapped_to = "blk."+str(i)+".ffn_gate" tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth + tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth # Feed-forward down mapped_to = "blk."+str(i)+".ffn_down" tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox - tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2 - tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt - tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon - tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf - tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth + tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth return tensor_map diff --git a/llama.cpp b/llama.cpp index 5504f1483..2ac907b96 100644 --- a/llama.cpp +++ b/llama.cpp @@ -72,6 +72,7 @@ static void llama_log_callback_default(llama_log_level level, const char * text, #define LLAMA_MAX_SCRATCH_BUFFERS 16 #endif +#define UNUSED GGML_UNUSED // available llama models enum e_model { @@ -1989,11 +1990,15 @@ static bool llama_is_eos_token(const llama_vocab& vocab, llama_token token) { } static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token token) { + UNUSED(vocab); + UNUSED(token); // TODO: improve? return false; } static bool llama_is_unused_token(const llama_vocab& vocab, llama_token token) { + UNUSED(vocab); + UNUSED(token); // TODO: improve? return false; } @@ -4399,21 +4404,21 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token if(llama_vocab_type(model->vocab) == "spm") { result = llama_unescape_whitespace(result); } - if(result.length() > length) { - return - result.length(); + if (length < (int) result.length()) { + return -result.length(); } - strcpy(str, result.c_str()); + strncpy(str, result.c_str(), result.length()); return result.length(); } else if (llama_is_unknown_token(model->vocab, token)) { - if(3 > length) { + if (length < 3) { return -3; } - strcpy(str, "\xe2\x96\x85"); + strncpy(str, "\xe2\x96\x85", 3); return 3; } else if (llama_is_control_token(model->vocab, token)) { ; } else if (llama_is_byte_token(model->vocab, token)) { - if(1 > length) { + if (length < 1) { return -1; } str[0] = llama_byte_to_char(model->vocab, token); @@ -4428,52 +4433,44 @@ int llama_token_to_str(const struct llama_context * ctx, llama_token token, char return llama_token_to_str_with_model(&ctx->model, token, str, length); } -std::string llama_token_to_str( - const struct llama_context * ctx, - llama_token token) { - std::string result; - int length = 8; - result.resize(length); - length = llama_token_to_str(ctx, token, (char *)result.data(), result.length()); +std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int length = llama_token_to_str(ctx, token, result.data(), result.size()); if (length < 0) { result.resize(-length); - int check = llama_token_to_str(ctx, token, (char *)result.data(), result.length()); - assert(check == -length); - GGML_UNUSED(check); + int check = llama_token_to_str(ctx, token, result.data(), result.size()); + GGML_ASSERT(check == -length); } else { result.resize(length); } - return result; + + return std::string(result.data(), result.size()); } int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * str, int length) { if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) { std::string result = ctx->model.vocab.id_to_token[token].tok; - if (result.length() > length) { - return - result.length(); + if (length < (int) result.length()) { + return -result.length(); } - strcpy(str, result.c_str()); + strncpy(str, result.c_str(), result.length()); return result.length(); } return 0; } -std::string llama_token_to_str_bpe( - const struct llama_context * ctx, - llama_token token) { - std::string result; - int length = 8; - result.resize(length); - length = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.length()); +std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int length = llama_token_to_str_bpe(ctx, token, result.data(), result.size()); if (length < 0) { result.resize(-length); - int check = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.length()); - assert(check == -length); - GGML_UNUSED(check); + const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size()); + GGML_ASSERT(check == -length); } else { result.resize(length); } - return result; + + return std::string(result.data(), result.size()); } llama_token llama_token_bos() { diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index a523c320c..f973271a3 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -8,14 +8,13 @@ static std::string unescape_whitespace(llama_context* ctx, const std::vector& tokens) { std::string result; - for (int i = 0; i < tokens.size(); ++i) { + for (size_t i = 0; i < tokens.size(); ++i) { result += llama_token_to_str(ctx, tokens[i]); } return result; } -static const std::map> & k_tests() -{ +static const std::map> & k_tests() { static std::map> _k_tests = { { " ", {1, 259, }, }, { "\t", { 1, 29871, 12, }, }, @@ -29,17 +28,18 @@ static const std::map> & k_tests() { " this is πŸ¦™.cpp", { 1, 29871, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, }, { "w048 7tuijk dsdfhu", { 1, 281, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, }, { "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ", { 1, 1538, 4851, 665, 1386, 29713, 1305, }, }, - { "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰", { 1, 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161, - 146, 228, 162, 133, 228, 161, 153, 228, 161, 186, - 31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228, - 161, 136, 228, 161, 132, 228, 161, 158, 228, 161, - 136, 228, 162, 132, 228, 161, 140, }, }, + { "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰", { 1, 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161, + 146, 228, 162, 133, 228, 161, 153, 228, 161, 186, + 31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228, + 161, 136, 228, 161, 132, 228, 161, 158, 228, 161, + 136, 228, 162, 132, 228, 161, 140, }, }, { "πŸš€ (normal) πŸ˜Άβ€πŸŒ«οΈ (multiple emojis concatenated) βœ… (only emoji that has its own token)", - { 1, 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871, - 243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598, - 313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681, - 313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, }, - }; + { 1, 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871, + 243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598, + 313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681, + 313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, }, + }; + return _k_tests; }; @@ -90,8 +90,8 @@ int main(int argc, char **argv) { } for (const auto & test_kv : k_tests()) { - std::vector res = llama_tokenize(ctx, test_kv.first.c_str(), true); - fprintf(stderr, "%s : '%s' tokenized to '%s'\n", + std::vector res = llama_tokenize(ctx, test_kv.first, true); + fprintf(stderr, "%s : '%s' tokenized to '%s'\n", __func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str()); bool correct = res.size() == test_kv.second.size(); diff --git a/tests/test-tokenizer-1.cpp b/tests/test-tokenizer-1.cpp index 122e51684..a44f38cd7 100644 --- a/tests/test-tokenizer-1.cpp +++ b/tests/test-tokenizer-1.cpp @@ -8,8 +8,9 @@ #include #include #include +#include -static std::string vocab_type(llama_context* ctx) { +static std::string vocab_type(llama_context * ctx) { return llama_n_vocab(ctx) == 32000 ? "spm": "bpe"; } @@ -32,9 +33,9 @@ static std::string escape_whitespace(const std::string& text) { return result; } -static std::string unescape_whitespace(llama_context* ctx, const std::vector& tokens) { +static std::string unescape_whitespace(llama_context * ctx, const std::vector & tokens) { std::string result; - for (int i = 0; i < tokens.size(); ++i) { + for (size_t i = 0; i < tokens.size(); ++i) { result += llama_token_to_str(ctx, tokens[i]); } return result; @@ -85,17 +86,17 @@ int main(int argc, char **argv) { if (tokens.size() == 1) { if (i != tokens[0]) { std::string backward = llama_token_to_str(ctx, tokens[0]); - fprintf(stderr, "%s : error: token %d is string %s but bpe returns token %d %s\n", + fprintf(stderr, "%s : error: token %d is string %s but bpe returns token %d %s\n", __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str()); return 2; } } else { - if ((vocab_type(ctx) == "spm" && i <= 258) || + if ((vocab_type(ctx) == "spm" && i <= 258) || (vocab_type(ctx) == "bpe" && (i == 0 || i >= 100000))) { - fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n", + fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n", __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str()); } else { - fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n", + fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n", __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str()); return 2; } @@ -105,10 +106,15 @@ int main(int argc, char **argv) { std::wstring_convert, wchar_t> converter; for (wchar_t ch = 0x0000; ch < 0xffff; ++ch) { std::wstring wstr(1, ch); - std::string str = converter.to_bytes(wstr); - std::vector tokens = llama_tokenize(ctx, escape_whitespace(str).c_str(), false); + std::string str; + try { + str = converter.to_bytes(wstr); + } catch (std::exception & e) { + continue; + } + std::vector tokens = llama_tokenize(ctx, escape_whitespace(str), false); if (tokens.size() == 1) { - fprintf(stderr, "%s : info: %s tokenized to %d \n", + fprintf(stderr, "%s : info: %s tokenized to %d \n", __func__, str.c_str(), tokens[0]); } }