From 5f6e0c0dff1e7a89331e6b25eca9a9fd71324069 Mon Sep 17 00:00:00 2001 From: Marcus Dunn <51931484+MarcusDunn@users.noreply.github.com> Date: Tue, 5 Dec 2023 10:55:12 -1000 Subject: [PATCH] grammar : pre-computed pieces + reserve mem + less string copies (#4330) * reserve space for codepoints * improvement for the appended 0 * used precomputed token text for grammar sample * reserve canidates_decoded * reserve canidates_grammar * remove candidates_decoded * Revert "remove candidates_decoded" This reverts commit 3773328080e6a139ee83198329a13cf4ff61d707. * changed decode_utf8 to take src by ref --- llama.cpp | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/llama.cpp b/llama.cpp index b77020e10..14e5d312e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6851,14 +6851,13 @@ struct llama_grammar_candidate { // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. static std::pair, llama_partial_utf8> decode_utf8( - const char * src, - size_t n_src, + const std::string & src, llama_partial_utf8 partial_start) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; - const char * pos = src; + const char * pos = src.c_str(); std::vector code_points; // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0. - code_points.reserve(n_src + 1); + code_points.reserve(src.size() + 1); uint32_t value = partial_start.value; int n_remain = partial_start.n_remain; @@ -6909,13 +6908,6 @@ static std::pair, llama_partial_utf8> decode_utf8( return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); } -static std::pair, llama_partial_utf8> decode_utf8( - std::string src, - llama_partial_utf8 partial_start -) { - return decode_utf8(src.c_str(), src.size(), partial_start); -} - // returns true iff pos points to the end of one of the definitions of a rule static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { switch (pos->type) { @@ -7554,11 +7546,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c const llama_token eos = llama_token_eos(&ctx->model); std::vector, llama_partial_utf8>> candidates_decoded; + candidates_decoded.reserve(candidates->size); std::vector candidates_grammar; + candidates_grammar.reserve(candidates->size); for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string piece = llama_token_to_piece(ctx, id); + const std::string & piece = ctx->model.vocab.id_to_token[id].text; if (id == eos) { if (!allow_eos) { candidates->data[i].logit = -INFINITY; @@ -7770,7 +7764,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar GGML_ASSERT(false); } - const std::string piece = llama_token_to_piece(ctx, token); + const std::string & piece = ctx->model.vocab.id_to_token[token].text; // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar->partial_utf8);