diff --git a/Makefile b/Makefile index bdd5ef335..11b31c5c8 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ TEST_TARGETS = \ tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \ tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \ tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease \ - tests/test-json-schema-to-grammar + tests/test-json-schema-to-grammar tests/test-grammar-integration # Code coverage output files COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report @@ -918,6 +918,10 @@ tests/test-grammar-parser: tests/test-grammar-parser.cpp ggml.o llama.o grammar- $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o llama.o grammar-parser.o $(OBJS) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a43439aed..b5d7bb59c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -59,6 +59,7 @@ llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt2 AR llama_test(test-grammar-parser.cpp) llama_test(test-llama-grammar.cpp) +llama_test(test-grammar-integration.cpp) llama_test(test-grad0.cpp) # llama_test(test-opt.cpp) # SLOW llama_test(test-backend-ops.cpp) diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp new file mode 100644 index 000000000..0a9c3b6f5 --- /dev/null +++ b/tests/test-grammar-integration.cpp @@ -0,0 +1,243 @@ +#ifdef NDEBUG +#undef NDEBUG +#endif + +#define LLAMA_API_INTERNAL + +#include "ggml.h" +#include "llama.h" +#include "grammar-parser.h" +#include "unicode.h" +#include +#include + +static void test_simple_grammar() { + // Test case for a simple grammar + const std::string grammar_str = R"""(root ::= expr +expr ::= term ("+" term)* +term ::= number +number ::= [0-9]+)"""; + + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + + // Ensure we parsed correctly + assert(!parsed_grammar.rules.empty()); + + // Ensure we have a root node + assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); + + std::vector grammar_rules(parsed_grammar.c_rules()); + llama_grammar* grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + + std::string input = "123+456"; + + auto decoded = decode_utf8(input, {}); + + const auto & code_points = decoded.first; + + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + auto prev_stacks = grammar->stacks; + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + assert(!grammar->stacks.empty()); + } + + bool completed_grammar = false; + + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + completed_grammar = true; + break; + } + } + + assert(completed_grammar); + + // Clean up allocated memory + llama_grammar_free(grammar); +} + +static void test_complex_grammar() { + // Test case for a more complex grammar, with both failure strings and success strings + const std::string grammar_str = R"""(root ::= expression +expression ::= term ws (("+"|"-") ws term)* +term ::= factor ws (("*"|"/") ws factor)* +factor ::= number | variable | "(" expression ")" | function-call +number ::= [0-9]+ +variable ::= [a-zA-Z_][a-zA-Z0-9_]* +function-call ::= variable ws "(" (expression ("," ws expression)*)? ")" +ws ::= [ \t\n\r]?)"""; + + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + + // Ensure we parsed correctly + assert(!parsed_grammar.rules.empty()); + + // Ensure we have a root node + assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); + + std::vector grammar_rules(parsed_grammar.c_rules()); + llama_grammar* grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + + // Save the original grammar stacks so that we can reset after every new string we want to test + auto original_stacks = grammar->stacks; + + // Test a few strings + std::vector test_strings_pass = { + "42", + "1*2*3*4*5", + "x", + "x+10", + "x1+y2", + "(a+b)*(c-d)", + "func()", + "func(x,y+2)", + "a*(b+c)-d/e", + "f(g(x),h(y,z))", + "x + 10", + "x1 + y2", + "(a + b) * (c - d)", + "func()", + "func(x, y + 2)", + "a * (b + c) - d / e", + "f(g(x), h(y, z))", + "123+456", + "123*456*789-123/456+789*123", + "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" + }; + + std::vector test_strings_fail = { + "+", + "/ 3x", + "x + + y", + "a * / b", + "func(,)", + "func(x y)", + "(a + b", + "x + y)", + "a + b * (c - d", + "42 +", + "x +", + "x + 10 +", + "(a + b) * (c - d", + "func(", + "func(x, y + 2", + "a * (b + c) - d /", + "f(g(x), h(y, z)", + "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/", + }; + + // Passing strings + for (const auto & test_string : test_strings_pass) { + auto decoded = decode_utf8(test_string, {}); + + const auto & code_points = decoded.first; + + int pos = 0; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + ++pos; + auto prev_stacks = grammar->stacks; + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + + // Expect that each code point will not cause the grammar to fail + if (grammar->stacks.empty()) { + fprintf(stdout, "Error at position %d\n", pos); + fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str()); + fprintf(stderr, "Input string is %s:\n", test_string.c_str()); + } + assert(!grammar->stacks.empty()); + } + + bool completed_grammar = false; + + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + completed_grammar = true; + break; + } + } + + assert(completed_grammar); + + // Reset the grammar stacks + grammar->stacks = original_stacks; + } + + // Failing strings + for (const auto & test_string : test_strings_fail) { + auto decoded = decode_utf8(test_string, {}); + + const auto & code_points = decoded.first; + bool parse_failed = false; + + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + auto prev_stacks = grammar->stacks; + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + if (grammar->stacks.empty()) { + parse_failed = true; + break; + } + assert(!grammar->stacks.empty()); + } + + bool completed_grammar = false; + + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + completed_grammar = true; + break; + } + } + + // Ensure that the grammar is not completed, or that each string failed to match as-expected + assert((!completed_grammar) || parse_failed); + + // Reset the grammar stacks + grammar->stacks = original_stacks; + } + + // Clean up allocated memory + llama_grammar_free(grammar); +} + +static void test_failure_missing_root() { + // Test case for a grammar that is missing a root rule + const std::string grammar_str = R"""(rot ::= expr +expr ::= term ("+" term)* +term ::= number +number ::= [0-9]+)"""; + + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + + // Ensure we parsed correctly + assert(!parsed_grammar.rules.empty()); + + // Ensure we do NOT have a root node + assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()); +} + +static void test_failure_missing_reference() { + // Test case for a grammar that is missing a referenced rule + const std::string grammar_str = R"""(root ::= expr +expr ::= term ("+" term)* +term ::= numero +number ::= [0-9]+)"""; + + fprintf(stderr, "Expected error: "); + + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + + // Ensure we did NOT parsed correctly + assert(parsed_grammar.rules.empty()); + + fprintf(stderr, "End of expected error. Test successful.\n"); +} + +int main() { + test_simple_grammar(); + test_complex_grammar(); + test_failure_missing_root(); + test_failure_missing_reference(); + return 0; +}