mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 13:58:46 +01:00
grammars: 1.5x faster inference w/ complex grammars (vector reserves / reuses) (#6609)
* grammars: reserve rejects & next candidates * grammars: reuse new_stacks * grammars: fix missing sig change in llama.h * grammars: fix test (api changed) * grammars: update gbnf-validator.cpp * grammars: simpler syntax (no swap)
This commit is contained in:
parent
1bbdaf6ecd
commit
cbaadc9294
@ -17,7 +17,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
|
|||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
auto prev_stacks = grammar->stacks;
|
auto prev_stacks = grammar->stacks;
|
||||||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
||||||
if (grammar->stacks.empty()) {
|
if (grammar->stacks.empty()) {
|
||||||
error_pos = pos;
|
error_pos = pos;
|
||||||
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
|
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
|
||||||
|
16
llama.cpp
16
llama.cpp
@ -11912,12 +11912,13 @@ static void llama_grammar_advance_stack(
|
|||||||
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
||||||
// produces the N possible stacks if the given char is accepted at those
|
// produces the N possible stacks if the given char is accepted at those
|
||||||
// positions
|
// positions
|
||||||
std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
|
void llama_grammar_accept(
|
||||||
const std::vector<std::vector<llama_grammar_element>> & rules,
|
const std::vector<std::vector<llama_grammar_element>> & rules,
|
||||||
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
|
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
|
||||||
const uint32_t chr) {
|
const uint32_t chr,
|
||||||
|
std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {
|
||||||
|
|
||||||
std::vector<std::vector<const llama_grammar_element *>> new_stacks;
|
new_stacks.clear();
|
||||||
|
|
||||||
for (const auto & stack : stacks) {
|
for (const auto & stack : stacks) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
@ -11936,8 +11937,6 @@ std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
|
|||||||
llama_grammar_advance_stack(rules, new_stack, new_stacks);
|
llama_grammar_advance_stack(rules, new_stack, new_stacks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return new_stacks;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
|
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
|
||||||
@ -11951,6 +11950,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
|
|||||||
const std::vector<llama_grammar_candidate> & candidates) {
|
const std::vector<llama_grammar_candidate> & candidates) {
|
||||||
|
|
||||||
std::vector<llama_grammar_candidate> rejects;
|
std::vector<llama_grammar_candidate> rejects;
|
||||||
|
rejects.reserve(candidates.size());
|
||||||
|
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
for (const auto & tok : candidates) {
|
for (const auto & tok : candidates) {
|
||||||
@ -11964,6 +11964,8 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
|
|||||||
const llama_grammar_element * stack_pos = stack.back();
|
const llama_grammar_element * stack_pos = stack.back();
|
||||||
|
|
||||||
std::vector<llama_grammar_candidate> next_candidates;
|
std::vector<llama_grammar_candidate> next_candidates;
|
||||||
|
next_candidates.reserve(candidates.size());
|
||||||
|
|
||||||
for (const auto & tok : candidates) {
|
for (const auto & tok : candidates) {
|
||||||
if (*tok.code_points == 0) {
|
if (*tok.code_points == 0) {
|
||||||
// reached end of full codepoints in token, reject iff it ended in a partial sequence
|
// reached end of full codepoints in token, reject iff it ended in a partial sequence
|
||||||
@ -12771,8 +12773,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
|
|||||||
// Note terminating 0 in decoded string
|
// Note terminating 0 in decoded string
|
||||||
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
|
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
|
||||||
const auto & code_points = decoded.first;
|
const auto & code_points = decoded.first;
|
||||||
|
std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
|
||||||
|
grammar->stacks = tmp_new_stacks;
|
||||||
}
|
}
|
||||||
grammar->partial_utf8 = decoded.second;
|
grammar->partial_utf8 = decoded.second;
|
||||||
GGML_ASSERT(!grammar->stacks.empty());
|
GGML_ASSERT(!grammar->stacks.empty());
|
||||||
|
5
llama.h
5
llama.h
@ -1097,10 +1097,11 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
|
|||||||
struct llama_context * ctx
|
struct llama_context * ctx
|
||||||
);
|
);
|
||||||
|
|
||||||
std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
|
void llama_grammar_accept(
|
||||||
const std::vector<std::vector<llama_grammar_element>> & rules,
|
const std::vector<std::vector<llama_grammar_element>> & rules,
|
||||||
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
|
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
|
||||||
const uint32_t chr);
|
const uint32_t chr,
|
||||||
|
std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
|
||||||
|
|
||||||
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
||||||
const std::string & src,
|
const std::string & src,
|
||||||
|
@ -38,7 +38,7 @@ number ::= [0-9]+)""";
|
|||||||
|
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
auto prev_stacks = grammar->stacks;
|
auto prev_stacks = grammar->stacks;
|
||||||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
||||||
assert(!grammar->stacks.empty());
|
assert(!grammar->stacks.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,7 +138,7 @@ ws ::= [ \t\n\r]?)""";
|
|||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
++pos;
|
++pos;
|
||||||
auto prev_stacks = grammar->stacks;
|
auto prev_stacks = grammar->stacks;
|
||||||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
||||||
|
|
||||||
// Expect that each code point will not cause the grammar to fail
|
// Expect that each code point will not cause the grammar to fail
|
||||||
if (grammar->stacks.empty()) {
|
if (grammar->stacks.empty()) {
|
||||||
@ -173,7 +173,7 @@ ws ::= [ \t\n\r]?)""";
|
|||||||
|
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
auto prev_stacks = grammar->stacks;
|
auto prev_stacks = grammar->stacks;
|
||||||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
||||||
if (grammar->stacks.empty()) {
|
if (grammar->stacks.empty()) {
|
||||||
parse_failed = true;
|
parse_failed = true;
|
||||||
break;
|
break;
|
||||||
|
Loading…
Reference in New Issue
Block a user