#ifdef NDEBUG #undef NDEBUG #endif #include "llama.h" #include "llama-grammar.h" #include <cassert> #include <stdexcept> int main() { llama_grammar_parser parsed_grammar; std::vector<std::pair<std::string, uint32_t>> expected = { {"expr", 2}, {"expr_6", 6}, {"expr_7", 7}, {"ident", 8}, {"ident_10", 10}, {"num", 9}, {"num_11", 11}, {"root", 0}, {"root_1", 1}, {"root_5", 5}, {"term", 4}, {"ws", 3}, {"ws_12", 12}, }; std::vector<std::vector<llama_grammar_element>> expected_rules = { {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}}, { {LLAMA_GRETYPE_RULE_REF, 2}, {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_CHAR, 10}, {LLAMA_GRETYPE_END, 0}, }, {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}}, {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}}, { {LLAMA_GRETYPE_RULE_REF, 8}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 9}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_CHAR, 40}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_RULE_REF, 2}, {LLAMA_GRETYPE_CHAR, 41}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}, }, {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}}, { {LLAMA_GRETYPE_CHAR, 45}, {LLAMA_GRETYPE_CHAR_ALT, 43}, {LLAMA_GRETYPE_CHAR_ALT, 42}, {LLAMA_GRETYPE_CHAR_ALT, 47}, {LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_END, 0}, }, {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}}, { {LLAMA_GRETYPE_CHAR, 97}, {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122}, {LLAMA_GRETYPE_RULE_REF, 10}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}, }, {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}}, { {LLAMA_GRETYPE_CHAR, 97}, {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122}, {LLAMA_GRETYPE_CHAR_ALT, 48}, {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, {LLAMA_GRETYPE_CHAR_ALT, 95}, {LLAMA_GRETYPE_RULE_REF, 10}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}, }, { {LLAMA_GRETYPE_CHAR, 48}, {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, {LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_CHAR, 48}, {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, {LLAMA_GRETYPE_END, 0}, }, { {LLAMA_GRETYPE_CHAR, 32}, {LLAMA_GRETYPE_CHAR_ALT, 9}, {LLAMA_GRETYPE_CHAR_ALT, 10}, {LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}, }, }; for (auto pair : expected) { parsed_grammar.symbol_ids[pair.first] = pair.second; } for (auto rule : expected_rules) { parsed_grammar.rules.emplace_back(); for (auto element : rule) { parsed_grammar.rules.back().push_back(element); } } llama_grammar * grammar = NULL; std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); } std::vector<std::vector<llama_grammar_element>> expected_stacks = { { {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_CHAR, 97}, }, { {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_CHAR, 48}, }, { {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_CHAR, 48}, }, { {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_CHAR, 40}, }, { {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_CHAR, 97}, }, { {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_CHAR, 48}, }, { {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_CHAR, 48}, }, { {LLAMA_GRETYPE_CHAR, 61}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_CHAR, 40}, }}; auto index = 0; for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar)) { // compare stack to expected_stack for (uint32_t i = 0; i < stack.size(); i++) { const llama_grammar_element * element = stack[i]; const llama_grammar_element & expected_element = expected_stacks[index][i]; // pretty print error message before asserting if (expected_element.type != element->type || expected_element.value != element->value) { fprintf(stderr, "index: %d\n", index); fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value); fprintf(stderr, "expected_element != actual_element\n"); } assert(expected_element.type == element->type && expected_element.value == element->value); } index++; } std::vector<llama_grammar_candidate> next_candidates; next_candidates.resize(24); for (size_t i = 0; i < 24; ++i) { uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point cp[0] = 37 + i; cp[1] = 0; next_candidates[i] = {i, cp, {}}; } std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = { { {0, 37}, {1, 38}, {2, 39}, {3, 40}, {4, 41}, {5, 42}, {6, 43}, {7, 44}, {8, 45}, {9, 46}, {10, 47}, {11, 48}, {12, 49}, {13, 50}, {14, 51}, {15, 52}, {16, 53}, {17, 54}, {18, 55}, {19, 56}, {20, 57}, {21, 58}, {22, 59}, {23, 60}, }, { {0, 37}, {1, 38}, {2, 39}, {3, 40}, {4, 41}, {5, 42}, {6, 43}, {7, 44}, {8, 45}, {9, 46}, {10, 47}, {21, 58}, {22, 59}, {23, 60}, }, { {0, 37}, {1, 38}, {2, 39}, {3, 40}, {4, 41}, {5, 42}, {6, 43}, {7, 44}, {8, 45}, {9, 46}, {10, 47}, {21, 58}, {22, 59}, {23, 60}, }, { {0, 37}, {1, 38}, {2, 39}, {4, 41}, {5, 42}, {6, 43}, {7, 44}, {8, 45}, {9, 46}, {10, 47}, {11, 48}, {12, 49}, {13, 50}, {14, 51}, {15, 52}, {16, 53}, {17, 54}, {18, 55}, {19, 56}, {20, 57}, {21, 58}, {22, 59}, {23, 60}, }, { {0, 37}, {1, 38}, {2, 39}, {3, 40}, {4, 41}, {5, 42}, {6, 43}, {7, 44}, {8, 45}, {9, 46}, {10, 47}, {11, 48}, {12, 49}, {13, 50}, {14, 51}, {15, 52}, {16, 53}, {17, 54}, {18, 55}, {19, 56}, {20, 57}, {21, 58}, {22, 59}, {23, 60}, }, { {0, 37}, {1, 38}, {2, 39}, {3, 40}, {4, 41}, {5, 42}, {6, 43}, {7, 44}, {8, 45}, {9, 46}, {10, 47}, {21, 58}, {22, 59}, {23, 60}, }, { {0, 37}, {1, 38}, {2, 39}, {3, 40}, {4, 41}, {5, 42}, {6, 43}, {7, 44}, {8, 45}, {9, 46}, {10, 47}, {21, 58}, {22, 59}, {23, 60}, }, { {0, 37}, {1, 38}, {2, 39}, {4, 41}, {5, 42}, {6, 43}, {7, 44}, {8, 45}, {9, 46}, {10, 47}, {11, 48}, {12, 49}, {13, 50}, {14, 51}, {15, 52}, {16, 53}, {17, 54}, {18, 55}, {19, 56}, {20, 57}, {21, 58}, {22, 59}, {23, 60}, }, }; std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates); std::vector<std::vector<llama_grammar_candidate>> all_rejects; for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count) { rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates); all_rejects.push_back(rejects); } index = 0; for (auto rej : all_rejects) { for (uint32_t i = 0; i < rej.size(); i++) { auto element = rej[i]; auto expected_element = expected_reject[index][i]; assert(element.index == expected_element.first && *element.code_points == expected_element.second); } index++; } for (auto &candidate : next_candidates) { delete[] candidate.code_points; candidate.code_points = nullptr; } llama_grammar_free_impl(grammar); return 0; }