From 463173a6c0ff353055eb90665794884c888c790f Mon Sep 17 00:00:00 2001
From: Kawrakow <48489457+ikawrakow@users.noreply.github.com>
Date: Sun, 27 Aug 2023 16:50:33 +0300
Subject: [PATCH] llama : speedup tokenization (#2831)

* Speedup tokenization

On current master it takes ~3.2 seconds to tokenize Wikitext.
With this change it becomes ~525 ms.

* Fixit: it was missing the piece after the last found occurrence

---------

Co-authored-by: Iwan Kawrakow
---
 examples/perplexity/perplexity.cpp |  4 ++++
 llama.cpp                          | 15 ++++++++++-----
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index b596d0626..ebafa0c29 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -190,10 +190,14 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
     const bool add_bos = is_spm;
 
+    auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
     auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
+    auto tim2 = std::chrono::high_resolution_clock::now();
+    fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
+
     const int n_chunk_max = tokens.size() / params.n_ctx;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
diff --git a/llama.cpp b/llama.cpp
index 0d12d9cca..0bb8fcd6e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -114,12 +114,17 @@ static size_t utf8_len(char src) {
 }
 
 void replace_all(std::string & s, const std::string & search, const std::string & replace) {
-    for (size_t pos = 0; ; pos += replace.length()) {
-        pos = s.find(search, pos);
-        if (pos == std::string::npos) break;
-        s.erase(pos, search.length());
-        s.insert(pos, replace);
+    std::string result;
+    for (size_t pos = 0; ; pos += search.length()) {
+        auto new_pos = s.find(search, pos);
+        if (new_pos == std::string::npos) {
+            result += s.substr(pos, s.size() - pos);
+            break;
+        }
+        result += s.substr(pos, new_pos - pos) + replace;
+        pos = new_pos;
     }
+    s = std::move(result);
 }
 
 static void zeros(std::ofstream & file, size_t n) {
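
Note (not part of the patch): the functional change is the rewrite of replace_all. The old version repeatedly called erase/insert on s, shifting the tail of the string on every match; the new version walks the string once and appends the unmatched pieces plus the replacement to a separate result buffer, which is where the tokenization speedup comes from. The sketch below copies the new function body from the diff verbatim into a standalone program; the main() driver and the sample strings are a hypothetical demo harness added here only for illustration.

    // Standalone sketch of the patched replace_all; main() is a demo harness, not llama.cpp code.
    #include <cstdio>
    #include <string>
    #include <utility>

    void replace_all(std::string & s, const std::string & search, const std::string & replace) {
        std::string result;
        for (size_t pos = 0; ; pos += search.length()) {
            auto new_pos = s.find(search, pos);
            if (new_pos == std::string::npos) {
                result += s.substr(pos, s.size() - pos);  // keep the piece after the last found occurrence
                break;
            }
            result += s.substr(pos, new_pos - pos) + replace;  // copy unmatched prefix, then the replacement
            pos = new_pos;
        }
        s = std::move(result);
    }

    int main() {
        std::string s = "a b c b a";
        replace_all(s, " ", "_");
        printf("%s\n", s.c_str());  // prints: a_b_c_b_a
        return 0;
    }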