mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 03:12:23 +01:00
Add left recursion check: quit early instead of going into an infinite loop (#7083)
* Add left recursion check: quit early instead of going into an infinite loop * Remove custom enum, rename left recursion check and move to "grammar internal" section, add handling for edge case where a leftmost nonterminal may be empty * Remove unnecessary declaration
This commit is contained in:
parent
27f65d6267
commit
e0f556186b
68
llama.cpp
68
llama.cpp
@ -13182,6 +13182,58 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
|
|||||||
return rejects;
|
return rejects;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool llama_grammar_detect_left_recursion(
|
||||||
|
const std::vector<std::vector<llama_grammar_element>> & rules,
|
||||||
|
size_t rule_index,
|
||||||
|
std::vector<bool> * rules_visited,
|
||||||
|
std::vector<bool> * rules_in_progress,
|
||||||
|
std::vector<bool> * rules_may_be_empty) {
|
||||||
|
if ((*rules_in_progress)[rule_index]) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
(*rules_in_progress)[rule_index] = true;
|
||||||
|
|
||||||
|
const std::vector<llama_grammar_element> & rule = rules[rule_index];
|
||||||
|
|
||||||
|
// First check if the rule might produce the empty string. This could be done combined with the second
|
||||||
|
// step but it's more readable as two steps.
|
||||||
|
bool at_rule_start = true;
|
||||||
|
for (size_t i = 0; i < rule.size(); i++) {
|
||||||
|
if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
||||||
|
if (at_rule_start) {
|
||||||
|
(*rules_may_be_empty)[rule_index] = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
at_rule_start = true;
|
||||||
|
} else {
|
||||||
|
at_rule_start = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
|
||||||
|
// be empty)
|
||||||
|
bool recurse_into_nonterminal = true;
|
||||||
|
for (size_t i = 0; i < rule.size(); i++) {
|
||||||
|
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
|
||||||
|
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
|
||||||
|
recurse_into_nonterminal = false;
|
||||||
|
}
|
||||||
|
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
||||||
|
recurse_into_nonterminal = true;
|
||||||
|
} else {
|
||||||
|
recurse_into_nonterminal = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(*rules_in_progress)[rule_index] = false;
|
||||||
|
(*rules_visited)[rule_index] = true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// grammar - external
|
// grammar - external
|
||||||
//
|
//
|
||||||
@ -13201,6 +13253,19 @@ struct llama_grammar * llama_grammar_init(
|
|||||||
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
|
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for left recursion
|
||||||
|
std::vector<bool> rules_visited(n_rules);
|
||||||
|
std::vector<bool> rules_in_progress(n_rules);
|
||||||
|
std::vector<bool> rules_may_be_empty(n_rules);
|
||||||
|
for (size_t i = 0; i < n_rules; i++) {
|
||||||
|
if (rules_visited[i]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
|
||||||
|
throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// loop over alternates of start rule to build initial stacks
|
// loop over alternates of start rule to build initial stacks
|
||||||
std::vector<std::vector<const llama_grammar_element *>> stacks;
|
std::vector<std::vector<const llama_grammar_element *>> stacks;
|
||||||
pos = vec_rules[start_rule_index].data();
|
pos = vec_rules[start_rule_index].data();
|
||||||
@ -13223,6 +13288,9 @@ struct llama_grammar * llama_grammar_init(
|
|||||||
}
|
}
|
||||||
} while (true);
|
} while (true);
|
||||||
|
|
||||||
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||||
|
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||||
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
|
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,6 +28,19 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
|
|||||||
return grammar;
|
return grammar;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool test_build_grammar_fails(const std::string & grammar_str) {
|
||||||
|
fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
|
||||||
|
bool grammar_fails = false;
|
||||||
|
try {
|
||||||
|
build_grammar(grammar_str);
|
||||||
|
fprintf(stderr, " ❌ Expected build failure, but succeeded\n");
|
||||||
|
} catch (const std::exception & err) {
|
||||||
|
grammar_fails = true;
|
||||||
|
fprintf(stdout, " ✅︎\n");
|
||||||
|
}
|
||||||
|
return grammar_fails;
|
||||||
|
}
|
||||||
|
|
||||||
static bool match_string(const std::string & input, llama_grammar* grammar) {
|
static bool match_string(const std::string & input, llama_grammar* grammar) {
|
||||||
auto decoded = decode_utf8(input, {});
|
auto decoded = decode_utf8(input, {});
|
||||||
|
|
||||||
@ -320,6 +333,38 @@ number ::= [0-9]+)""";
|
|||||||
fprintf(stderr, " ✅︎ Passed\n");
|
fprintf(stderr, " ✅︎ Passed\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void test_failure_left_recursion() {
|
||||||
|
fprintf(stderr, "⚫ Testing left recursion detection:\n");
|
||||||
|
|
||||||
|
// Test simple left recursion detection
|
||||||
|
const std::string simple_str = R"""(root ::= "a" | root "a")""";
|
||||||
|
assert(test_build_grammar_fails(simple_str));
|
||||||
|
|
||||||
|
// Test more complicated left recursion detection
|
||||||
|
const std::string medium_str = R"""(
|
||||||
|
root ::= asdf
|
||||||
|
asdf ::= "a" | asdf "a"
|
||||||
|
)""";
|
||||||
|
assert(test_build_grammar_fails(medium_str));
|
||||||
|
|
||||||
|
// Test even more complicated left recursion detection
|
||||||
|
const std::string hard_str = R"""(
|
||||||
|
root ::= asdf
|
||||||
|
asdf ::= "a" | foo "b"
|
||||||
|
foo ::= "c" | asdf "d" | "e")""";
|
||||||
|
assert(test_build_grammar_fails(hard_str));
|
||||||
|
|
||||||
|
// Test yet even more complicated left recursion detection
|
||||||
|
const std::string hardest_str = R"""(
|
||||||
|
root ::= asdf
|
||||||
|
asdf ::= "a" | foo "b"
|
||||||
|
foo ::= "c" | empty asdf "d" | "e"
|
||||||
|
empty ::= "blah" | )""";
|
||||||
|
assert(test_build_grammar_fails(hardest_str));
|
||||||
|
|
||||||
|
fprintf(stderr, " ✅︎ Passed\n");
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
fprintf(stdout, "Running grammar integration tests...\n");
|
fprintf(stdout, "Running grammar integration tests...\n");
|
||||||
test_simple_grammar();
|
test_simple_grammar();
|
||||||
@ -327,6 +372,7 @@ int main() {
|
|||||||
test_quantifiers();
|
test_quantifiers();
|
||||||
test_failure_missing_root();
|
test_failure_missing_root();
|
||||||
test_failure_missing_reference();
|
test_failure_missing_reference();
|
||||||
|
test_failure_left_recursion();
|
||||||
fprintf(stdout, "All tests passed.\n");
|
fprintf(stdout, "All tests passed.\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user