From 4e3880978f8b1bf546dd4e6f3b524d6b8739c49c Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Fri, 10 May 2024 07:01:08 -0400 Subject: [PATCH] Fix memory bug in grammar parser (#7194) The llama.cpp grammar parser had a bug where forgetting to add a closing quotation mark to strings would cause parsing to crash. Anyone running a server on a public endpoint is advised to upgrade. To reproduce this bug ./llamafile -m foo.gguf -p bar --grammar 'root::="' Credit for discovering and reporting this issue goes to Eclypsium Security Researcher Richard Johnson . --- common/common.cpp | 8 +++----- common/grammar-parser.cpp | 9 +++++++++ examples/llava/llava-cli.cpp | 5 +++++ examples/main/main.cpp | 4 ++++ 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 484e67334..ba1ecf0e5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1371,14 +1371,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { std::replace(arg.begin(), arg.end(), '_', '-'); } - if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) { throw std::invalid_argument("error: unknown argument: " + arg); } - } - - if (invalid_param) { - throw std::invalid_argument("error: invalid parameter for argument: " + arg); + if (invalid_param) { + throw std::invalid_argument("error: invalid parameter for argument: " + arg); + } } if (params.prompt_cache_all && diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index 2a1301569..fecb7cd71 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -142,6 +142,9 @@ namespace grammar_parser { pos++; last_sym_start = out_elements.size(); while (*pos != '"') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } auto char_pair = parse_char(pos); pos = char_pair.second; out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); @@ -156,6 +159,9 @@ namespace grammar_parser { } last_sym_start = out_elements.size(); while (*pos != ']') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } auto char_pair = parse_char(pos); pos = char_pair.second; enum llama_gretype type = last_sym_start < out_elements.size() @@ -164,6 +170,9 @@ namespace grammar_parser { out_elements.push_back({type, char_pair.first}); if (pos[0] == '-' && pos[1] != ']') { + if (!pos[1]) { + throw std::runtime_error("unexpected end of input"); + } auto endchar_pair = parse_char(pos + 1); pos = endchar_pair.second; out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 157a680b5..da60ddf2f 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -189,6 +189,11 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG_TEE("\n"); struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); + if (!ctx_sampling) { + fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); + exit(1); + } + std::string response = ""; for (int i = 0; i < max_tgt_len; i++) { const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index f3e445c16..9dee41001 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -523,6 +523,10 @@ int main(int argc, char ** argv) { } struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + if (!ctx_sampling) { + fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); + exit(1); + } while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict