diff --git a/Makefile b/Makefile
index 332496cfc..89287831f 100644
--- a/Makefile
+++ b/Makefile
@@ -927,7 +927,6 @@ OBJ_COMMON = \
 	common/ngram-cache.o \
 	common/sampling.o \
 	common/train.o \
-	common/grammar-parser.o \
 	common/build-info.o \
 	common/json-schema-to-grammar.o

@@ -1167,11 +1166,6 @@ common/console.o: \
 	common/console.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

-common/grammar-parser.o: \
-	common/grammar-parser.cpp \
-	common/grammar-parser.h
-	$(CXX) $(CXXFLAGS) -c $< -o $@
-
 common/json-schema-to-grammar.o: \
 	common/json-schema-to-grammar.cpp \
 	common/json-schema-to-grammar.h
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 761971d68..2c72793b8 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -58,8 +58,6 @@ add_library(${TARGET} STATIC
     sampling.cpp
     console.h
     console.cpp
-    grammar-parser.h
-    grammar-parser.cpp
     json.hpp
     json-schema-to-grammar.cpp
     train.h
diff --git a/common/common.cpp b/common/common.cpp
index 9fa184725..077e4c0a1 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -353,16 +353,15 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model)
 }

 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
-    bool invalid_param = false;
-    std::string arg;
-    const std::string arg_prefix = "--";
-    llama_sampling_params & sparams = params.sparams;
-
     for (int i = 1; i < argc; i++) {
-        arg = argv[i];
+        const std::string arg_prefix = "--";
+
+        std::string arg = argv[i];
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
+
+        bool invalid_param = false;
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
         }
@@ -386,11 +385,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         get_env("HF_TOKEN", params.hf_token);
     }

+    auto & sparams = params.sparams;
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
         string_process_escapes(params.input_suffix);
-        string_process_escapes(sparams.cfg_negative_prompt);
         for (auto & antiprompt : params.antiprompt) {
             string_process_escapes(antiprompt);
         }
@@ -401,6 +401,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         params.kv_overrides.back().key[0] = 0;
     }

+    if (sparams.seed == LLAMA_DEFAULT_SEED) {
+        sparams.seed = time(NULL);
+    }
+
     return true;
 }

@@ -526,12 +530,10 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     const char split_delim = ',';

-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;

     if (arg == "-s" || arg == "--seed") {
         CHECK_ARG
-        // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
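(Illustrative sketch, not part of the patch.) The TODO removed just above is what this change resolves: the RNG seed now lives only in the sampling parameters, and gpt_params_parse_ex randomizes it at parse time when it is left at LLAMA_DEFAULT_SEED. A minimal sketch of calling code after the migration, assuming only fields shown in this diff; print_seed is a hypothetical helper:

    #include <cstdio>

    #include "common.h" // gpt_params now embeds gpt_sampling_params as params.sparams

    // hypothetical helper: the seed is read from params.sparams, no longer from params
    static void print_seed(const gpt_params & params) {
        // uint32_t; already randomized by gpt_params_parse_ex if it was LLAMA_DEFAULT_SEED
        fprintf(stderr, "seed: %u\n", params.sparams.seed);
    }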
- params.seed = std::stoul(argv[i]); sparams.seed = std::stoul(argv[i]); return true; } @@ -842,12 +844,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa if (arg == "--samplers") { CHECK_ARG const auto sampler_names = string_split(argv[i], ';'); - sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true); + sparams.samplers = llama_sampling_types_from_names(sampler_names, true); return true; } if (arg == "--sampling-seq") { CHECK_ARG - sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]); + sparams.samplers = llama_sampling_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { @@ -873,7 +875,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } if (arg == "--typical") { CHECK_ARG - sparams.typical_p = std::stof(argv[i]); + sparams.typ_p = std::stof(argv[i]); return true; } if (arg == "--repeat-last-n") { @@ -922,30 +924,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa sparams.mirostat_tau = std::stof(argv[i]); return true; } - if (arg == "--cfg-negative-prompt") { - CHECK_ARG - sparams.cfg_negative_prompt = argv[i]; - return true; - } - if (arg == "--cfg-negative-prompt-file") { - CHECK_ARG - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; - } - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(sparams.cfg_negative_prompt)); - if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { - sparams.cfg_negative_prompt.pop_back(); - } - return true; - } - if (arg == "--cfg-scale") { - CHECK_ARG - sparams.cfg_scale = std::stof(argv[i]); - return true; - } if (arg == "-b" || arg == "--batch-size") { CHECK_ARG params.n_batch = std::stoi(argv[i]); @@ -1353,7 +1331,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--ignore-eos") { - params.ignore_eos = true; + sparams.ignore_eos = true; return true; } if (arg == "--penalize-nl") { @@ -1368,7 +1346,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa std::string value_str; try { if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { - sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); + const float bias = std::stof(value_str) * ((sign == '-') ? 
-1.0f : 1.0f); + sparams.logit_bias.push_back({key, bias}); } else { throw std::exception(); @@ -1715,13 +1694,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa #endif void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { - const llama_sampling_params & sparams = params.sparams; + const auto & sparams = params.sparams; std::string sampler_type_chars; std::string sampler_type_names; - for (const auto sampler_type : sparams.samplers_sequence) { - sampler_type_chars += static_cast(sampler_type); - sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";"; + for (const auto & sampler : sparams.samplers) { + sampler_type_chars += llama_sampling_type_to_chr(sampler); + sampler_type_names += llama_sampling_type_to_str(sampler) + ";"; } sampler_type_names.pop_back(); @@ -1756,7 +1735,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" }); options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" }); options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" }); - options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed }); options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.cpuparams.n_threads }); options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" }); options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" }); @@ -1836,18 +1814,19 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param " --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" }); options.push_back({ "sampling" }); + options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed }); options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n" "(default: %s)", sampler_type_names.c_str() }); options.push_back({ "*", " --sampling-seq SEQUENCE", "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() }); options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" }); options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? 
"true" : "false" }); - options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp }); + options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp }); options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k }); - options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p }); - options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p }); - options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z }); - options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p }); + options.push_back({ "*", " --top-p P", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p }); + options.push_back({ "*", " --min-p P", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p }); + options.push_back({ "*", " --tfs P", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z }); + options.push_back({ "*", " --typical P", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typ_p }); options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n }); options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat }); options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present }); @@ -1862,11 +1841,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n" "i.e. 
`--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" }); - options.push_back({ "main", " --cfg-negative-prompt PROMPT", - "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() }); - options.push_back({ "main", " --cfg-negative-prompt-file FNAME", - "negative prompt file to use for guidance" }); - options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale }); options.push_back({ "main", " --chat-template JINJA_TEMPLATE", "set custom jinja chat template (default: template taken from model's metadata)\n" "if suffix/prefix are specified, template will be disabled\n" @@ -2513,8 +2487,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { llama_lora_adapters_apply(lctx, iparams.lora_adapters); } - if (params.ignore_eos) { - params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + if (params.sparams.ignore_eos && llama_token_eos(model) == -1) { + fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); + params.sparams.ignore_eos = false; } if (params.warmup) { @@ -2543,7 +2518,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { } llama_kv_cache_clear(lctx); llama_synchronize(lctx); - llama_reset_timings(lctx); + llama_reset_timings(lctx, nullptr); } iparams.model = model; @@ -2622,7 +2597,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? params.cpuparams.n_threads : params.cpuparams_batch.n_threads; - cparams.seed = params.seed; cparams.logits_all = params.logits_all; cparams.embeddings = params.embedding; cparams.rope_scaling_type = params.rope_scaling_type; @@ -3508,7 +3482,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx, const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { - const llama_sampling_params & sparams = params.sparams; + const auto & sparams = params.sparams; fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); @@ -3559,8 +3533,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); - fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); @@ -3571,10 +3543,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? 
"true" : "false"); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - - const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx))); - const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; - fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); + fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false"); yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); @@ -3585,11 +3554,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); fprintf(stream, "logit_bias:\n"); - for (std::pair lb : sparams.logit_bias) { - if (ignore_eos && lb.first == logit_bias_eos->first) { - continue; - } - fprintf(stream, " %d: %f", lb.first, lb.second); + for (const auto & logit_bias : sparams.logit_bias) { + fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias); } fprintf(stream, "lora:\n"); @@ -3642,7 +3608,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); - fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); @@ -3656,7 +3621,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); - fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); + fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? 
"true" : "false"); } diff --git a/common/common.h b/common/common.h index cb5e7f6df..f9dbd0e49 100644 --- a/common/common.h +++ b/common/common.h @@ -77,8 +77,6 @@ struct cpu_params { }; struct gpt_params { - uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed - int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 0; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) @@ -120,8 +118,7 @@ struct gpt_params { enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings - // // sampling parameters - struct llama_sampling_params sparams; + struct gpt_sampling_params sparams; std::string model = ""; // model path std::string model_draft = ""; // draft model for speculative decoding @@ -185,7 +182,6 @@ struct gpt_params { bool flash_attn = false; // flash attention bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix - bool ignore_eos = false; // ignore generated EOS tokens bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp deleted file mode 100644 index 438452eab..000000000 --- a/common/grammar-parser.cpp +++ /dev/null @@ -1,539 +0,0 @@ -#include "grammar-parser.h" -#include -#include -#include -#include -#include -#include - -namespace grammar_parser { - // NOTE: assumes valid utf8 (but checks for overrun) - // copied from llama.cpp - static std::pair decode_utf8(const char * src) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t first_byte = static_cast(*src); - uint8_t highbits = first_byte >> 4; - int len = lookup[highbits]; - uint8_t mask = (1 << (8 - len)) - 1; - uint32_t value = first_byte & mask; - const char * end = src + len; // may overrun! 
- const char * pos = src + 1; - for ( ; pos < end && *pos; pos++) { - value = (value << 6) + (static_cast(*pos) & 0x3F); - } - return std::make_pair(value, pos); - } - - static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { - uint32_t next_id = static_cast(state.symbol_ids.size()); - auto result = state.symbol_ids.emplace(std::string(src, len), next_id); - return result.first->second; - } - - static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { - uint32_t next_id = static_cast(state.symbol_ids.size()); - state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; - return next_id; - } - - static void add_rule( - parse_state & state, - uint32_t rule_id, - const std::vector & rule) { - if (state.rules.size() <= rule_id) { - state.rules.resize(rule_id + 1); - } - state.rules[rule_id] = rule; - } - - static bool is_digit_char(char c) { - return '0' <= c && c <= '9'; - } - - static bool is_word_char(char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); - } - - static std::pair parse_hex(const char * src, int size) { - const char * pos = src; - const char * end = src + size; - uint32_t value = 0; - for ( ; pos < end && *pos; pos++) { - value <<= 4; - char c = *pos; - if ('a' <= c && c <= 'f') { - value += c - 'a' + 10; - } else if ('A' <= c && c <= 'F') { - value += c - 'A' + 10; - } else if ('0' <= c && c <= '9') { - value += c - '0'; - } else { - break; - } - } - if (pos != end) { - throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); - } - return std::make_pair(value, pos); - } - - static const char * parse_space(const char * src, bool newline_ok) { - const char * pos = src; - while (*pos == ' ' || *pos == '\t' || *pos == '#' || - (newline_ok && (*pos == '\r' || *pos == '\n'))) { - if (*pos == '#') { - while (*pos && *pos != '\r' && *pos != '\n') { - pos++; - } - } else { - pos++; - } - } - return pos; - } - - static const char * parse_name(const char * src) { - const char * pos = src; - while (is_word_char(*pos)) { - pos++; - } - if (pos == src) { - throw std::runtime_error(std::string("expecting name at ") + src); - } - return pos; - } - - static const char * parse_int(const char * src) { - const char * pos = src; - while (is_digit_char(*pos)) { - pos++; - } - if (pos == src) { - throw std::runtime_error(std::string("expecting integer at ") + src); - } - return pos; - } - - static std::pair parse_char(const char * src) { - if (*src == '\\') { - switch (src[1]) { - case 'x': return parse_hex(src + 2, 2); - case 'u': return parse_hex(src + 2, 4); - case 'U': return parse_hex(src + 2, 8); - case 't': return std::make_pair('\t', src + 2); - case 'r': return std::make_pair('\r', src + 2); - case 'n': return std::make_pair('\n', src + 2); - case '\\': - case '"': - case '[': - case ']': - return std::make_pair(src[1], src + 2); - default: - throw std::runtime_error(std::string("unknown escape at ") + src); - } - } else if (*src) { - return decode_utf8(src); - } - throw std::runtime_error("unexpected end of input"); - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested); - - static const char * parse_sequence( - parse_state & state, - const char * src, - const std::string & rule_name, - std::vector & out_elements, - bool is_nested) { - size_t last_sym_start = out_elements.size(); - const char * pos = src; - - auto handle_repetitions = [&](int 
min_times, int max_times) { - - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); - } - - // apply transformation to previous symbol (last_sym_start to end) according to - // the following rewrite rules: - // S{m,n} --> S S S (m times) S'(n-m) - // S'(x) ::= S S'(x-1) | - // (... n-m definitions of these S' rules ...) - // S'(1) ::= S | - // S{m,} --> S S S (m times) S' - // S' ::= S S' | - // S* --> S{0,} - // --> S' ::= S S' | - // S+ --> S{1,} - // --> S S' - // S' ::= S S' | - // S? --> S{0,1} - // --> S' - // S' ::= S | - - std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); - if (min_times == 0) { - out_elements.resize(last_sym_start); - } else { - // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { - out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); - } - } - - uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; - - std::vector rec_rule(previous_elements); - for (int i = 0; i < n_opt; i++) { - rec_rule.resize(previous_elements.size()); - uint32_t rec_rule_id = generate_symbol_id(state, rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); - } - rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - rec_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rec_rule_id, rec_rule); - last_rec_rule_id = rec_rule_id; - } - if (n_opt > 0) { - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); - } - }; - - while (*pos) { - if (*pos == '"') { // literal string - pos++; - last_sym_start = out_elements.size(); - while (*pos != '"') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '[') { // char range(s) - pos++; - enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - if (*pos == '^') { - pos++; - start_type = LLAMA_GRETYPE_CHAR_NOT; - } - last_sym_start = out_elements.size(); - while (*pos != ']') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - enum llama_gretype type = last_sym_start < out_elements.size() - ? 
LLAMA_GRETYPE_CHAR_ALT - : start_type; - - out_elements.push_back({type, char_pair.first}); - if (pos[0] == '-' && pos[1] != ']') { - if (!pos[1]) { - throw std::runtime_error("unexpected end of input"); - } - auto endchar_pair = parse_char(pos + 1); - pos = endchar_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); - } - } - pos = parse_space(pos + 1, is_nested); - } else if (is_word_char(*pos)) { // rule reference - const char * name_end = parse_name(pos); - uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); - pos = parse_space(name_end, is_nested); - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - } else if (*pos == '(') { // grouping - // parse nested alternates into synthesized rule - pos = parse_space(pos + 1, true); - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); - last_sym_start = out_elements.size(); - // output reference to synthesized rule - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - if (*pos != ')') { - throw std::runtime_error(std::string("expecting ')' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '.') { // any char - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, -1); - } else if (*pos == '+') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(1, -1); - } else if (*pos == '?') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, 1); - } else if (*pos == '{') { - pos = parse_space(pos + 1, is_nested); - - if (!is_digit_char(*pos)) { - throw std::runtime_error(std::string("expecting an int at ") + pos); - } - const char * int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - - int max_times = -1; - - if (*pos == '}') { - max_times = min_times; - pos = parse_space(pos + 1, is_nested); - } else if (*pos == ',') { - pos = parse_space(pos + 1, is_nested); - - if (is_digit_char(*pos)) { - const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - } - - if (*pos != '}') { - throw std::runtime_error(std::string("expecting '}' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else { - throw std::runtime_error(std::string("expecting ',' at ") + pos); - } - handle_repetitions(min_times, max_times); - } else { - break; - } - } - return pos; - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested) { - std::vector rule; - const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); - while (*pos == '|') { - rule.push_back({LLAMA_GRETYPE_ALT, 0}); - pos = parse_space(pos + 1, true); - pos = parse_sequence(state, pos, rule_name, rule, is_nested); - } - rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rule_id, rule); - return pos; - } - - static const char * parse_rule(parse_state & state, const char * src) { - const char * name_end = parse_name(src); - const char * pos = parse_space(name_end, false); - size_t name_len = name_end - src; - uint32_t rule_id = get_symbol_id(state, src, name_len); - const std::string name(src, name_len); - - if (!(pos[0] == ':' 
&& pos[1] == ':' && pos[2] == '=')) { - throw std::runtime_error(std::string("expecting ::= at ") + pos); - } - pos = parse_space(pos + 3, true); - - pos = parse_alternates(state, pos, name, rule_id, false); - - if (*pos == '\r') { - pos += pos[1] == '\n' ? 2 : 1; - } else if (*pos == '\n') { - pos++; - } else if (*pos) { - throw std::runtime_error(std::string("expecting newline or end at ") + pos); - } - return parse_space(pos, true); - } - - parse_state parse(const char * src) { - try { - parse_state state; - const char * pos = parse_space(src, true); - while (*pos) { - pos = parse_rule(state, pos); - } - // Validate the state to ensure that all rules are defined - for (const auto & rule : state.rules) { - if (rule.empty()) { - throw std::runtime_error("Undefined rule"); - } - for (const auto & elem : rule) { - if (elem.type == LLAMA_GRETYPE_RULE_REF) { - // Ensure that the rule at that location exists - if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) { - // Get the name of the rule that is missing - for (const auto & kv : state.symbol_ids) { - if (kv.second == elem.value) { - throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); - } - } - } - } - } - } - return state; - } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); - return parse_state(); - } - } - - static void print_grammar_char(FILE * file, uint32_t c) { - if (0x20 <= c && c <= 0x7f) { - fprintf(file, "%c", static_cast(c)); - } else { - // cop out of encoding UTF-8 - fprintf(file, "", c); - } - } - - static bool is_char_element(llama_grammar_element elem) { - switch (elem.type) { - case LLAMA_GRETYPE_CHAR: return true; - case LLAMA_GRETYPE_CHAR_NOT: return true; - case LLAMA_GRETYPE_CHAR_ALT: return true; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; - case LLAMA_GRETYPE_CHAR_ANY: return true; - default: return false; - } - } - - static void print_rule_binary(FILE * file, const std::vector & rule) { - for (auto elem : rule) { - switch (elem.type) { - case LLAMA_GRETYPE_END: fprintf(file, "END"); break; - case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; - case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; - case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; - case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; - case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; - case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; - } - switch (elem.type) { - case LLAMA_GRETYPE_END: - case LLAMA_GRETYPE_ALT: - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "(%u) ", elem.value); - break; - case LLAMA_GRETYPE_CHAR: - case LLAMA_GRETYPE_CHAR_NOT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "(\""); - print_grammar_char(file, elem.value); - fprintf(file, "\") "); - break; - } - } - fprintf(file, "\n"); - } - - static void print_rule( - FILE * file, - uint32_t rule_id, - const std::vector & rule, - const std::map & symbol_id_names) { - if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { - throw std::runtime_error( - "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); - } - fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - for (size_t i = 0, end = rule.size() - 1; i < end; i++) { - llama_grammar_element elem = rule[i]; - switch (elem.type) { - case LLAMA_GRETYPE_END: - throw 
std::runtime_error( - "unexpected end of rule: " + std::to_string(rule_id) + "," + - std::to_string(i)); - case LLAMA_GRETYPE_ALT: - fprintf(file, "| "); - break; - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); - break; - case LLAMA_GRETYPE_CHAR: - fprintf(file, "["); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_NOT: - fprintf(file, "[^"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - fprintf(file, "-"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ALT: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "."); - break; - } - if (is_char_element(elem)) { - switch (rule[i + 1].type) { - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ANY: - break; - default: - fprintf(file, "] "); - } - } - } - fprintf(file, "\n"); - } - - void print_grammar(FILE * file, const parse_state & state) { - try { - std::map symbol_id_names; - for (const auto & kv : state.symbol_ids) { - symbol_id_names[kv.second] = kv.first; - } - for (size_t i = 0, end = state.rules.size(); i < end; i++) { - // fprintf(file, "%zu: ", i); - // print_rule_binary(file, state.rules[i]); - print_rule(file, uint32_t(i), state.rules[i], symbol_id_names); - // fprintf(file, "\n"); - } - } catch (const std::exception & err) { - fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); - } - } - - std::vector parse_state::c_rules() { - std::vector ret; - ret.reserve(rules.size()); - for (const auto & rule : rules) { - ret.push_back(rule.data()); - } - return ret; - } -} diff --git a/common/grammar-parser.h b/common/grammar-parser.h deleted file mode 100644 index 9037d7272..000000000 --- a/common/grammar-parser.h +++ /dev/null @@ -1,29 +0,0 @@ -// Implements a parser for an extended Backus-Naur form (BNF), producing the -// binary context-free grammar format specified by llama.h. Supports character -// ranges, grouping, and repetition operators. 
As an example, a grammar for -// arithmetic might look like: -// -// root ::= expr -// expr ::= term ([-+*/] term)* -// term ::= num | "(" space expr ")" space -// num ::= [0-9]+ space -// space ::= [ \t\n]* - -#pragma once -#include "llama.h" -#include -#include -#include -#include - -namespace grammar_parser { - struct parse_state { - std::map symbol_ids; - std::vector> rules; - - std::vector c_rules(); - }; - - parse_state parse(const char * src); - void print_grammar(FILE * file, const parse_state & state); -} diff --git a/common/sampling.cpp b/common/sampling.cpp index 079e40516..96cfbe0ef 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,141 +1,28 @@ -#define LLAMA_API_INTERNAL #include "sampling.h" -#include -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { - struct llama_sampling_context * result = new llama_sampling_context(); +#include "common.h" - result->params = params; - result->grammar = nullptr; - - // if there is a grammar, parse it - if (!params.grammar.empty()) { - result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - - // will be empty (default) if there are parse errors - if (result->parsed_grammar.rules.empty()) { - fprintf(stderr, "%s: failed to parse grammar\n", __func__); - delete result; - return nullptr; - } - - // Ensure that there is a "root" node. - if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); - delete result; - return nullptr; - } - - std::vector grammar_rules(result->parsed_grammar.c_rules()); - - struct llama_grammar * grammar = llama_grammar_init( - grammar_rules.data(), - grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); - if (grammar == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); - } - result->grammar = grammar; - } - - result->prev.resize(params.n_prev); - - result->n_valid = 0; - - llama_sampling_set_rng_seed(result, params.seed); - - return result; -} - -void llama_sampling_free(struct llama_sampling_context * ctx) { - if (ctx->grammar != NULL) { - llama_grammar_free(ctx->grammar); - } - - delete ctx; -} - -void llama_sampling_reset(llama_sampling_context * ctx) { - if (ctx->grammar != NULL) { - llama_grammar_free(ctx->grammar); - ctx->grammar = NULL; - } - - if (!ctx->parsed_grammar.rules.empty()) { - std::vector grammar_rules(ctx->parsed_grammar.c_rules()); - - struct llama_grammar * grammar = llama_grammar_init( - grammar_rules.data(), - grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); - if (grammar == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); - } - ctx->grammar = grammar; - } - - std::fill(ctx->prev.begin(), ctx->prev.end(), 0); - ctx->cur.clear(); - ctx->n_valid = 0; -} - -void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = std::random_device{}(); - } - ctx->rng.seed(seed); -} - -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { - if (dst->grammar) { - llama_grammar_free(dst->grammar); - dst->grammar = nullptr; - } - - if (src->grammar) { - dst->grammar = llama_grammar_copy(src->grammar); - } - - dst->prev = src->prev; -} - -llama_token llama_sampling_last(llama_sampling_context * ctx) { - return ctx->prev.back(); -} - -std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, 
int n) { - const int size = ctx_sampling->prev.size(); - - n = std::min(n, size); - - std::string result; - - for (int i = size - n; i < size; i++) { - result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]); - } - - return result; -} - -std::string llama_sampling_print(const llama_sampling_params & params) { +std::string gpt_sampling_params::print_all() const { char result[1024]; snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", - params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, - params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, - params.mirostat, params.mirostat_eta, params.mirostat_tau); + penalty_last_n, penalty_repeat, penalty_freq, penalty_present, + top_k, tfs_z, top_p, min_p, typ_p, temp, + mirostat, mirostat_eta, mirostat_tau); return std::string(result); } -std::string llama_sampling_order_print(const llama_sampling_params & params) { +std::string gpt_sampling_params::print_samplers() const { std::string result = "CFG -> Penalties "; - if (params.mirostat == 0) { - for (auto sampler_type : params.samplers_sequence) { - const auto sampler_type_name = llama_sampling_type_to_str(sampler_type); - if (!sampler_type_name.empty()) { - result += "-> " + sampler_type_name + " "; + if (mirostat == 0) { + for (const auto & sampler : samplers) { + const auto name = llama_sampling_type_to_str(sampler); + if (!name.empty()) { + result += "-> " + name + " "; } } } else { @@ -145,316 +32,191 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) { return result; } -std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { - switch (sampler_type) { - case llama_sampler_type::TOP_K: return "top_k"; - case llama_sampler_type::TFS_Z: return "tfs_z"; - case llama_sampler_type::TYPICAL_P: return "typical_p"; - case llama_sampler_type::TOP_P: return "top_p"; - case llama_sampler_type::MIN_P: return "min_p"; - case llama_sampler_type::TEMPERATURE: return "temperature"; +struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) { + llama_sampling_params lparams = llama_sampling_default_params(); + + lparams.seed = params.seed; + lparams.n_prev = params.n_prev; + lparams.n_probs = params.n_probs; + lparams.min_keep = params.min_keep; + lparams.top_k = params.top_k; + lparams.top_p = params.top_p; + lparams.min_p = params.min_p; + lparams.tfs_z = params.tfs_z; + lparams.typ_p = params.typ_p; + lparams.temp = params.temp; + lparams.dynatemp_range = params.dynatemp_range; + lparams.dynatemp_exponent = params.dynatemp_exponent; + lparams.penalty_last_n = params.penalty_last_n; + lparams.penalty_repeat = params.penalty_repeat; + lparams.penalty_freq = params.penalty_freq; + lparams.penalty_present = params.penalty_present; + lparams.mirostat = params.mirostat; + lparams.mirostat_tau = params.mirostat_tau; + lparams.mirostat_eta = params.mirostat_eta; + lparams.penalize_nl = params.penalize_nl; + lparams.ignore_eos = params.ignore_eos; + + lparams.n_samplers = params.samplers.size(); + for (int i = 0; i < lparams.n_samplers; i++) { + lparams.samplers[i] = params.samplers[i]; + } + + struct llama_sampling * result = llama_sampling_init(model, lparams); + + llama_sampling_set_grammar (result, 
params.grammar.c_str(), "root"); + llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data()); + + return result; +} + +void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst) { + if (dst) { + llama_sampling_free(dst); + } + + dst = llama_sampling_cp(src); +} + +llama_token llama_sampling_sample( + struct llama_sampling * smpl, + struct llama_context * ctx, + int idx) { + llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + + // first, sample the token without any grammar constraints + const llama_token id = llama_sampling_sample(smpl, nullptr); + + // create an array with a single token data element for the sampled id + llama_token_data single_token_data = { id, 1.0f, 0.0f }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; + + llama_sampling_grammar(smpl, &single_token_data_array); + + // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY + const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + if (is_valid) { + return id; + } + + // if the token is not valid, sample again, after applying the grammar constraints + llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + + llama_sampling_grammar(smpl, nullptr); + + return llama_sampling_sample(smpl, nullptr); +} + +std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_main, int n) { + n = std::min(n, llama_sampling_n_prev(smpl)); + + if (n <= 0) { + return ""; + } + + std::string result; + result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab + + for (int i = n - 1; i >= 0; i--) { + const llama_token id = llama_sampling_prev(smpl, i); + + GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen"); + + result += llama_token_to_piece(ctx_main, id); + } + + return result; +} + +char llama_sampling_type_to_chr(llama_sampler_type sampler) { + switch (sampler) { + case LLAMA_SAMPLER_TYPE_TOP_K: return 'k'; + case LLAMA_SAMPLER_TYPE_TFS_Z: return 'f'; + case LLAMA_SAMPLER_TYPE_TYPICAL_P: return 'y'; + case LLAMA_SAMPLER_TYPE_TOP_P: return 'p'; + case LLAMA_SAMPLER_TYPE_MIN_P: return 'm'; + case LLAMA_SAMPLER_TYPE_TEMPERATURE: return 't'; + default : return '?'; + } +} + +std::string llama_sampling_type_to_str(llama_sampler_type sampler) { + switch (sampler) { + case LLAMA_SAMPLER_TYPE_TOP_K: return "top_k"; + case LLAMA_SAMPLER_TYPE_TFS_Z: return "tfs_z"; + case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; + case LLAMA_SAMPLER_TYPE_TOP_P: return "top_p"; + case LLAMA_SAMPLER_TYPE_MIN_P: return "min_p"; + case LLAMA_SAMPLER_TYPE_TEMPERATURE: return "temperature"; default : return ""; } } std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) { std::unordered_map sampler_canonical_name_map { - {"top_k", llama_sampler_type::TOP_K}, - {"top_p", llama_sampler_type::TOP_P}, - {"typical_p", llama_sampler_type::TYPICAL_P}, - {"min_p", llama_sampler_type::MIN_P}, - {"tfs_z", llama_sampler_type::TFS_Z}, - {"temperature", llama_sampler_type::TEMPERATURE} + { "top_k", LLAMA_SAMPLER_TYPE_TOP_K }, + { "top_p", LLAMA_SAMPLER_TYPE_TOP_P }, + { "typ_p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { "min_p", LLAMA_SAMPLER_TYPE_MIN_P }, + { "tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z }, + { "temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE }, }; // since samplers names are written multiple ways // make it ready for both system names and input names 
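As an illustration only (a sketch, not code from this patch), a round trip through these two name maps using just the helpers declared in this file; the input strings correspond to the --samplers and --sampling-seq handlers in common.cpp above:

    // "--sampling-seq kyt" and "--samplers top-k;typical;temp" resolve to the same chain
    const std::vector<llama_sampler_type> via_chars = llama_sampling_types_from_chars("kyt");
    const std::vector<llama_sampler_type> via_names = llama_sampling_types_from_names({"top-k", "typical", "temp"}, /*allow_alt_names=*/true);
    // both hold LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TYPICAL_P, LLAMA_SAMPLER_TYPE_TEMPERATURE,
    // which map back to 'k'/'y'/'t' via llama_sampling_type_to_chr() and to
    // "top_k"/"typ_p"/"temperature" via llama_sampling_type_to_str()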
std::unordered_map sampler_alt_name_map { - {"top-k", llama_sampler_type::TOP_K}, - {"top-p", llama_sampler_type::TOP_P}, - {"nucleus", llama_sampler_type::TOP_P}, - {"typical-p", llama_sampler_type::TYPICAL_P}, - {"typical", llama_sampler_type::TYPICAL_P}, - {"min-p", llama_sampler_type::MIN_P}, - {"tfs-z", llama_sampler_type::TFS_Z}, - {"tfs", llama_sampler_type::TFS_Z}, - {"temp", llama_sampler_type::TEMPERATURE} + { "top-k", LLAMA_SAMPLER_TYPE_TOP_K }, + { "top-p", LLAMA_SAMPLER_TYPE_TOP_P }, + { "nucleus", LLAMA_SAMPLER_TYPE_TOP_P }, + { "typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { "typical", LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { "typ-p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { "typ", LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { "min-p", LLAMA_SAMPLER_TYPE_MIN_P }, + { "tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z }, + { "tfs", LLAMA_SAMPLER_TYPE_TFS_Z }, + { "temp", LLAMA_SAMPLER_TYPE_TEMPERATURE }, }; - std::vector sampler_types; - sampler_types.reserve(names.size()); - for (const auto & name : names) - { - auto sampler_item = sampler_canonical_name_map.find(name); - if (sampler_item != sampler_canonical_name_map.end()) - { - sampler_types.push_back(sampler_item->second); - } - else - { - if (allow_alt_names) - { - sampler_item = sampler_alt_name_map.find(name); - if (sampler_item != sampler_alt_name_map.end()) - { - sampler_types.push_back(sampler_item->second); - } - } - } - } - return sampler_types; -} + std::vector samplers; + samplers.reserve(names.size()); -std::vector llama_sampling_types_from_chars(const std::string & names_string) { - std::unordered_map sampler_name_map { - {'k', llama_sampler_type::TOP_K}, - {'p', llama_sampler_type::TOP_P}, - {'y', llama_sampler_type::TYPICAL_P}, - {'m', llama_sampler_type::MIN_P}, - {'f', llama_sampler_type::TFS_Z}, - {'t', llama_sampler_type::TEMPERATURE} - }; - - std::vector sampler_types; - sampler_types.reserve(names_string.size()); - for (const auto & c : names_string) { - const auto sampler_item = sampler_name_map.find(c); - if (sampler_item != sampler_name_map.end()) { - sampler_types.push_back(sampler_item->second); - } - } - return sampler_types; -} - -// no reasons to expose this function in header -static void sampler_queue( - struct llama_context * ctx_main, - const llama_sampling_params & params, - llama_token_data_array & cur_p, - size_t min_keep) { - const float temp = params.temp; - const float dynatemp_range = params.dynatemp_range; - const float dynatemp_exponent = params.dynatemp_exponent; - const int32_t top_k = params.top_k; - const float top_p = params.top_p; - const float min_p = params.min_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const std::vector & samplers_sequence = params.samplers_sequence; - - for (auto sampler_type : samplers_sequence) { - switch (sampler_type) { - case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; - case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; - case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; - case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; - case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; - case llama_sampler_type::TEMPERATURE: - if (dynatemp_range > 0) { - float dynatemp_min = std::max(0.0f, temp - dynatemp_range); - float dynatemp_max = std::max(0.0f, temp + dynatemp_range); - llama_sample_entropy(ctx_main, &cur_p, 
dynatemp_min, dynatemp_max, dynatemp_exponent); - } else { - llama_sample_temp(ctx_main, &cur_p, temp); - } - break; - default : break; - } - } -} - -static llama_token llama_sampling_sample_impl( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx, - bool is_resampling) { - const llama_sampling_params & params = ctx_sampling->params; - - const float temp = params.temp; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; - - std::vector original_logits; - auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); - if (ctx_sampling->grammar != NULL && !is_resampling) { - GGML_ASSERT(!original_logits.empty()); - } - llama_token id = 0; - - if (temp < 0.0) { - // greedy sampling, with probs - llama_sample_softmax(ctx_main, &cur_p); - id = cur_p.data[0].id; - } else if (temp == 0.0) { - // greedy sampling, no probs - id = llama_sample_token_greedy(ctx_main, &cur_p); - } else { - if (mirostat == 1) { - const int mirostat_m = 100; - llama_sample_temp(ctx_main, &cur_p, temp); - id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); - } else if (mirostat == 2) { - llama_sample_temp(ctx_main, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); + for (const auto & name : names) { + auto sampler = sampler_canonical_name_map.find(name); + if (sampler != sampler_canonical_name_map.end()) { + samplers.push_back(sampler->second); } else { - // temperature sampling - size_t min_keep = std::max(1, params.min_keep); - - sampler_queue(ctx_main, params, cur_p, min_keep); - - id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); - - //{ - // const int n_top = 10; - // LOG("top %d candidates:\n", n_top); - - // for (int i = 0; i < n_top; i++) { - // const llama_token id = cur_p.data[i].id; - // (void)id; // To avoid a warning that id is unused when logging is disabled. - // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p); - // } - //} - - //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str()); - } - } - - if (ctx_sampling->grammar != NULL && !is_resampling) { - // Get a pointer to the logits - float * logits = llama_get_logits_ith(ctx_main, idx); - - // Create an array with a single token data element for the sampled id - llama_token_data single_token_data = {id, logits[id], 0.0f}; - llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; - - // Apply grammar constraints to the single token - llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array); - - // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY - bool is_valid = single_token_data_array.data[0].logit != -INFINITY; - - // If the token is not valid according to the grammar, perform resampling - if (!is_valid) { - LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str()); - - // Restore logits from the copy - std::copy(original_logits.begin(), original_logits.end(), logits); - - return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true); - } - } - - ctx_sampling->n_valid = temp == 0.0f ? 
0 : cur_p.size; - - return id; -} - -static llama_token_data_array llama_sampling_prepare_impl( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx, - bool apply_grammar, - std::vector * original_logits) { - const llama_sampling_params & params = ctx_sampling->params; - - const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); - - const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; - const float penalty_repeat = params.penalty_repeat; - const float penalty_freq = params.penalty_freq; - const float penalty_present = params.penalty_present; - - const bool penalize_nl = params.penalize_nl; - - auto & prev = ctx_sampling->prev; - auto & cur = ctx_sampling->cur; - - // Get a pointer to the logits - float * logits = llama_get_logits_ith(ctx_main, idx); - - if (ctx_sampling->grammar != NULL && !apply_grammar) { - GGML_ASSERT(original_logits != NULL); - // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this. - *original_logits = {logits, logits + n_vocab}; - } - - // apply params.logit_bias map - for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { - logits[it->first] += it->second; - } - - if (ctx_cfg) { - float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); - llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); - } - - cur.resize(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; - } - - llama_token_data_array cur_p = { cur.data(), cur.size(), false }; - - // apply penalties - const auto& penalty_tokens = params.use_penalty_prompt_tokens ? 
params.penalty_prompt_tokens : prev; - const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); - if (penalty_tokens_used_size) { - const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; - - llama_sample_repetition_penalties(ctx_main, &cur_p, - penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, - penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); - - if (!penalize_nl) { - for (size_t idx = 0; idx < cur_p.size; idx++) { - if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { - cur_p.data[idx].logit = nl_logit; - break; + if (allow_alt_names) { + sampler = sampler_alt_name_map.find(name); + if (sampler != sampler_alt_name_map.end()) { + samplers.push_back(sampler->second); } } } } - // apply grammar checks before sampling logic - if (apply_grammar && ctx_sampling->grammar != NULL) { - llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p); + return samplers; +} + +std::vector llama_sampling_types_from_chars(const std::string & chars) { + std::unordered_map sampler_name_map { + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_K), LLAMA_SAMPLER_TYPE_TOP_K }, + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TFS_Z), LLAMA_SAMPLER_TYPE_TFS_Z }, + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TYPICAL_P), LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_P), LLAMA_SAMPLER_TYPE_TOP_P }, + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_MIN_P), LLAMA_SAMPLER_TYPE_MIN_P }, + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TEMPERATURE), LLAMA_SAMPLER_TYPE_TEMPERATURE } + }; + + std::vector samplers; + samplers.reserve(chars.size()); + + for (const auto & c : chars) { + const auto sampler = sampler_name_map.find(c); + if (sampler != sampler_name_map.end()) { + samplers.push_back(sampler->second); + } } - return cur_p; -} - -llama_token llama_sampling_sample( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx) { - // Call the implementation function with is_resampling set to false by default - return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false); -} - -llama_token_data_array llama_sampling_prepare( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx, - bool apply_grammar, - std::vector * original_logits) { - return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits); -} - -void llama_sampling_accept( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - llama_token id, - bool apply_grammar) { - ctx_sampling->prev.erase(ctx_sampling->prev.begin()); - ctx_sampling->prev.push_back(id); - - if (ctx_sampling->grammar != NULL && apply_grammar) { - llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id); - } + return samplers; } diff --git a/common/sampling.h b/common/sampling.h index eeaa53b8b..b96bbce1c 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -2,159 +2,78 @@ #include "llama.h" -#include "grammar-parser.h" - -#include #include -#include #include -// sampler types -enum class llama_sampler_type : char { - TOP_K = 'k', - TOP_P = 'p', - MIN_P = 'm', - TFS_Z = 'f', - TYPICAL_P = 'y', - TEMPERATURE = 't' -}; - // sampling parameters -typedef struct llama_sampling_params { - int32_t n_prev = 64; // number of previous tokens to remember - int32_t 
n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token - uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context +typedef struct gpt_sampling_params { + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling - std::vector samplers_sequence = { - llama_sampler_type::TOP_K, - llama_sampler_type::TFS_Z, - llama_sampler_type::TYPICAL_P, - llama_sampler_type::TOP_P, - llama_sampler_type::MIN_P, - llama_sampler_type::TEMPERATURE + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typ_p = 1.00f; // typical_p, 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool penalize_nl = false; // consider newlines as a repeatable token + bool ignore_eos = false; + + std::vector samplers = { + LLAMA_SAMPLER_TYPE_TOP_K, + LLAMA_SAMPLER_TYPE_TFS_Z, + LLAMA_SAMPLER_TYPE_TYPICAL_P, + LLAMA_SAMPLER_TYPE_TOP_P, + LLAMA_SAMPLER_TYPE_MIN_P, + LLAMA_SAMPLER_TYPE_TEMPERATURE }; - std::string grammar; // optional BNF-like grammar to constrain sampling + std::string grammar; // optional BNF-like grammar to constrain sampling - // Classifier-Free Guidance - // https://arxiv.org/abs/2306.17806 - std::string cfg_negative_prompt; // string to help guidance - float cfg_scale = 1.f; // how strong is guidance + std::vector logit_bias; // logit biases to apply - std::unordered_map logit_bias; // logit bias for specific tokens + // print the parameters into a string + std::string print_all() const; - std::vector 
penalty_prompt_tokens; - bool use_penalty_prompt_tokens = false; -} llama_sampling_params; + // print the samplers into a string + std::string print_samplers() const; +} gpt_sampling_params; -// general sampler context -// TODO: move to llama.h -struct llama_sampling_context { - // parameters that will be used for sampling - llama_sampling_params params; +// overload of llama_sampling_init using gpt_sampling_params +struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params); - // mirostat sampler state - float mirostat_mu; +void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst); - llama_grammar * grammar; - - // internal - grammar_parser::parse_state parsed_grammar; - - // TODO: replace with ring-buffer - std::vector prev; - std::vector cur; - size_t n_valid; // Number of correct top tokens with correct probabilities. - - std::mt19937 rng; -}; - -#include "common.h" - -// Create a new sampling context instance. -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params); - -void llama_sampling_free(struct llama_sampling_context * ctx); - -// Reset the sampler context -// - clear prev tokens -// - reset grammar -void llama_sampling_reset(llama_sampling_context * ctx); - -// Set the sampler seed -void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed); - -// Copy the sampler context -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); - -// Get the last sampled token -llama_token llama_sampling_last(llama_sampling_context * ctx); - -// Get a string representation of the last sampled tokens -std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n); - -// Print sampling parameters into a string -std::string llama_sampling_print(const llama_sampling_params & params); - -// Print sampling order into a string -std::string llama_sampling_order_print(const llama_sampling_params & params); - -std::string llama_sampling_type_to_str(llama_sampler_type sampler_type); - -std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector llama_sampling_types_from_chars(const std::string & names_string); - -// this is a common sampling function used across the examples for convenience -// it can serve as a starting point for implementing your own sampling function -// Note: When using multiple sequences, it is the caller's responsibility to call -// llama_sampling_reset when a sequence ends +// common sampling implementation: // -// required: -// - ctx_main: context to use for sampling -// - ctx_sampling: sampling-specific context -// -// optional: -// - ctx_cfg: context to use for classifier-free guidance -// - idx: sample from llama_get_logits_ith(ctx, idx) -// -// returns: -// - token: sampled token -// - candidates: vector of candidate tokens +// - set logits +// - apply the configured sampling constraints +// - check if the token fits the grammar (if any) +// - if not: resample by first applying the grammar constraints and then sampling again (slower path) // llama_token llama_sampling_sample( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - int idx = -1); + struct llama_sampling * smpl, + struct llama_context * ctx, + int idx); -// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters. 
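As a reading aid for the reworked common/sampling.h API above: an example now goes through a single init/sample/accept/free cycle instead of owning a llama_sampling_context. The sketch below is illustrative only — the wrapper function name and the n_past bookkeeping are mine, not part of the patch — but every llama_* call it makes appears in this diff.

// minimal sketch, assuming `model`, `ctx`, `sparams` (a gpt_sampling_params) and `n_past` are already set up
#include "common.h"
#include "sampling.h"
#include <cstdio>

static void generate(llama_model * model, llama_context * ctx, const gpt_sampling_params & sparams, int n_past, int n_predict) {
    llama_sampling * smpl = llama_sampling_init(model, sparams);

    for (int i = 0; i < n_predict; ++i) {
        // sample from the last computed logits; grammar constraints are applied internally,
        // with the slower resampling path taken only if the first pick violates the grammar
        llama_token id = llama_sampling_sample(smpl, ctx, -1);

        // record the token so the repetition-penalty history and grammar state advance
        llama_sampling_accept(smpl, id, /* apply_grammar= */ true);

        if (llama_token_is_eog(model, id)) {
            break;
        }

        printf("%s", llama_token_to_piece(ctx, id).c_str());

        // feed the token back to obtain the next logits
        if (llama_decode(ctx, llama_batch_get_one(&id, 1, n_past++, 0))) {
            break;
        }
    }

    llama_sampling_free(smpl);
}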
-llama_token_data_array llama_sampling_prepare( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - int idx = 0, - bool apply_grammar = true, - std::vector * original_logits = nullptr); +// helpers -void llama_sampling_accept( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - llama_token id, - bool apply_grammar); +// get a string representation of the last accepted tokens +std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n); + +char llama_sampling_type_to_chr(enum llama_sampler_type sampler_type); +std::string llama_sampling_type_to_str(enum llama_sampler_type sampler_type); + +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector llama_sampling_types_from_chars(const std::string & chars); diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 25e7c775a..afd3f11d2 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -200,7 +200,7 @@ int main(int argc, char ** argv) { } } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_batch_free(batch); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 616494d2d..81763217a 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -27,7 +27,6 @@ guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), mo print("Failed to load model") exit(1) } - defer { llama_free_model(model) } @@ -37,7 +36,6 @@ var tokens = tokenize(text: prompt, add_bos: true) let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel) var context_params = llama_context_default_params() -context_params.seed = 1234 context_params.n_ctx = n_kv_req context_params.n_batch = UInt32(max(n_len, n_parallel)) context_params.n_threads = 8 @@ -48,11 +46,24 @@ guard context != nil else { print("Failed to initialize context") exit(1) } - defer { llama_free(context) } +var sparams = llama_sampling_params() +sparams.top_k = 40 +sparams.top_p = 0.9 +sparams.temp = 0.4 + +let smpl = llama_sampling_init(model, sparams) +guard smpl != nil else { + print("Failed to initialize sampling") + exit(1) +} +defer { + llama_sampling_free(smpl) +} + let n_ctx = llama_n_ctx(context) print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n") @@ -125,32 +136,17 @@ while n_cur <= n_len { continue } - var n_vocab = llama_n_vocab(model) var logits = llama_get_logits_ith(context, i_batch[i]) - var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab)) + llama_sampling_set_logits(smpl, logits) - for token_id in 0 ..< n_vocab { - candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0)) - } + llama_sampling_top_k(smpl, nil) + llama_sampling_top_p(smpl, nil) + llama_sampling_temp (smpl, nil) - var candidates_p: llama_token_data_array = .init( - data: &candidates, - size: candidates.count, - sorted: false - ) + let new_token_id = llama_sampling_sample_dist(smpl, nil) - let top_k: Int32 = 40 - let top_p: Float = 0.9 - let temp: Float = 0.4 - - llama_sample_top_k(context, &candidates_p, top_k, 1) - llama_sample_top_p(context, &candidates_p, top_p, 1) - llama_sample_temp(context, &candidates_p, temp) - - let new_token_id = 
llama_sample_token(context, &candidates_p) - - // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil); // is it an end of stream? -> mark the stream as finished if llama_token_is_eog(model, new_token_id) || n_cur == n_len { @@ -212,7 +208,7 @@ let t_main_end = ggml_time_us() print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n") -llama_print_timings(context) +llama_print_timings(context, smpl) private func tokenize(text: String, add_bos: Bool) -> [llama_token] { let utf8Count = text.utf8.count diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 53fbfb0a8..4dfa19ce8 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -2,7 +2,6 @@ #include "llama.h" #include -#include #include #include #include @@ -65,6 +64,15 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_new_context_with_model(model, ctx_params); + auto sparams = llama_sampling_default_params(); + + sparams.seed = params.sparams.seed; + sparams.top_k = 40; + sparams.top_p = 0.9f; + sparams.temp = 0.4f; + + llama_sampling * smpl = llama_sampling_init(model, sparams); + if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); return 1; @@ -164,29 +172,17 @@ int main(int argc, char ** argv) { continue; } - auto n_vocab = llama_n_vocab(model); - auto * logits = llama_get_logits_ith(ctx, i_batch[i]); + const auto * logits = llama_get_logits_ith(ctx, i_batch[i]); - std::vector candidates; - candidates.reserve(n_vocab); + llama_sampling_set_logits(smpl, logits); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } + llama_sampling_top_k(smpl, nullptr); + llama_sampling_top_p(smpl, nullptr); + llama_sampling_temp (smpl, nullptr); - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + const llama_token new_token_id = llama_sampling_sample_dist(smpl, nullptr); - const int top_k = 40; - const float top_p = 0.9f; - const float temp = 0.4f; - - llama_sample_top_k(ctx, &candidates_p, top_k, 1); - llama_sample_top_p(ctx, &candidates_p, top_p, 1); - llama_sample_temp (ctx, &candidates_p, temp); - - const llama_token new_token_id = llama_sample_token(ctx, &candidates_p); - - //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + //const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); // is it an end of generation? 
-> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -244,12 +240,13 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); fprintf(stderr, "\n"); llama_batch_free(batch); + llama_sampling_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index b05aa006e..4b288b460 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -90,13 +90,7 @@ int main(int argc, char ** argv) { print_build_info(); - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - fprintf(stderr, "%s: seed = %u\n", __func__, params.seed); - - std::mt19937 rng(params.seed); + LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); llama_backend_init(); llama_numa_init(params.numa); @@ -314,7 +308,7 @@ int main(int argc, char ** argv) { } // clean up - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_batch_free(batch); llama_free(ctx); llama_free_model(model); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 5e89988e2..166ca4b7d 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -151,8 +151,6 @@ int main(int argc, char ** argv) { print_build_info(); - std::mt19937 rng(params.seed); - llama_backend_init(); llama_numa_init(params.numa); @@ -183,7 +181,7 @@ int main(int argc, char ** argv) { return 1; } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_free(ctx); llama_free_model(model); diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 48a705e15..f439c0c56 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -1,9 +1,5 @@ -#define LLAMA_API_INTERNAL - -#include "grammar-parser.h" -#include "ggml.h" -#include "llama.h" #include "unicode.h" +#include "llama-grammar.h" #include #include @@ -12,22 +8,21 @@ #include #include -static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { - auto decoded = decode_utf8(input_str, {}); - const auto & code_points = decoded.first; +static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { + const auto cpts = unicode_cpts_from_utf8(input_str); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); size_t pos = 0; - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + for (const auto & cpt : cpts) { const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy - llama_grammar_accept(rules, prev_stacks, *it, cur_stacks); + cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); if (cur_stacks.empty()) { error_pos = pos; - error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'"; + error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; cur_stacks = prev_stacks; return false; } @@ -85,27 +80,7 @@ int main(int argc, char** argv) { grammar_str = buffer.str(); } - // Parse the GBNF grammar - 
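Before the grammar-validator changes that follow, here is the lower-level pattern the batched, passkey and Android/Swift examples above switch to: instead of building a llama_token_data array by hand, the raw logits are handed to the sampler and each constraint is applied explicitly. This is a sketch; pick_token is an illustrative name, the calls themselves mirror the diff.

// illustrative helper (not part of the patch): one sampling step for batch position i_batch
static llama_token pick_token(llama_sampling * smpl, llama_context * ctx, int i_batch) {
    const float * logits = llama_get_logits_ith(ctx, i_batch);

    // hand the raw logits to the sampler; no manual candidates vector any more
    llama_sampling_set_logits(smpl, logits);

    // apply the constraints configured at llama_sampling_init time
    llama_sampling_top_k(smpl, nullptr);
    llama_sampling_top_p(smpl, nullptr);
    llama_sampling_temp (smpl, nullptr);

    // draw from the resulting distribution; llama_sampling_sample_greedy(smpl, nullptr)
    // is the argmax variant used by passkey and the Android/Swift bindings
    return llama_sampling_sample_dist(smpl, nullptr);
}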
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); - - // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - fprintf(stdout, "%s: failed to parse grammar\n", __func__); - return 1; - } - - // Ensure that there is a "root" node. - if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) { - fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__); - return 1; - } - - std::vector grammar_rules(parsed_grammar.c_rules()); - - // Create the LLAMA grammar - auto grammar = llama_grammar_init( - grammar_rules.data(), - grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); } @@ -122,7 +97,7 @@ int main(int argc, char** argv) { // Validate the input string against the grammar size_t error_pos; std::string error_msg; - bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg); + bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg); if (is_valid) { fprintf(stdout, "Input string is valid according to the grammar.\n"); @@ -131,7 +106,7 @@ int main(int argc, char** argv) { } // Clean up - llama_grammar_free(grammar); + llama_grammar_free_impl(grammar); return 0; } diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 2c61c2e1e..7d2ae7713 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -9,7 +9,7 @@ static std::vector> encode(llama_context * ctx, const std::vector & sentences, const std::string & instruction) { std::vector> result; - const llama_model * mdl = llama_get_model(ctx); + const llama_model * model = llama_get_model(ctx); llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); @@ -18,16 +18,16 @@ static std::vector> encode(llama_context * ctx, const std::ve const std::string input_string = instruction + sentences[i]; - std::vector inputs = llama_tokenize(mdl, input_string, true, false); + std::vector inputs = llama_tokenize(model, input_string, true, false); const int32_t n_toks = inputs.size(); // GritLM seems to have EOS = "" // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 - // inputs.push_back(llama_token_eos(mdl)); + // inputs.push_back(llama_token_eos(model)); // we want to ignore instruction tokens for mean pooling - const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size(); + const int32_t n_inst = llama_tokenize(model, instruction, true, false).size(); #ifdef GRIT_DEBUG // debug tokens - should be matching as referenced in the GritLM sample @@ -51,7 +51,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_decode(ctx, batch); // get embedding dimensions - uint64_t n_embd = llama_n_embd(mdl); + uint64_t n_embd = llama_n_embd(model); // allocate embedding output std::vector emb_unorm(n_embd, 0.0f); @@ -92,11 +92,11 @@ static std::vector> encode(llama_context * ctx, const std::ve return result; } -static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) { +static std::string generate(llama_context * ctx, llama_sampling * smpl, const std::string & prompt, bool stream) { std::string result; - const llama_model * mdl = llama_get_model(ctx); - llama_token eos_token = llama_token_eos(mdl); + const llama_model * model = llama_get_model(ctx); + llama_token eos_token = 
llama_token_eos(model); llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, false); @@ -104,28 +104,27 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); - std::vector inputs = llama_tokenize(mdl, prompt, false, true); + std::vector inputs = llama_tokenize(model, prompt, false, true); int32_t i_current_token = 0; while (true) { llama_batch_clear(bat); - auto n_inputs = (int32_t)inputs.size(); - for (int32_t i = 0; i < n_inputs; i++) { - llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); + { + const int32_t n_inputs = inputs.size(); + + for (int32_t i = 0; i < n_inputs; i++) { + llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); + } } inputs.clear(); llama_decode(ctx, bat); - auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); - auto candidates = std::vector(llama_n_vocab(mdl)); - auto n_candidates = (int32_t)candidates.size(); - for (int32_t token = 0; token < n_candidates; token++) { - candidates[token] = llama_token_data{ token, logits[token], 0.0f }; - } - auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false }; + const auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); - llama_token token = llama_sample_token_greedy(ctx, &candidates_p); + llama_sampling_set_logits(smpl, logits); + + llama_token token = llama_sampling_sample_greedy(smpl, nullptr); if (token == eos_token) { break; } @@ -167,10 +166,12 @@ int main(int argc, char * argv[]) { llama_backend_init(); - llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); + llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); // create generation context - llama_context * ctx = llama_new_context_with_model(mdl, cparams); + llama_context * ctx = llama_new_context_with_model(model, cparams); + + llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic @@ -191,7 +192,7 @@ int main(int argc, char * argv[]) { const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); - const int n_embd = llama_n_embd(mdl); + const int n_embd = llama_n_embd(model); const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); @@ -208,11 +209,12 @@ int main(int argc, char * argv[]) { // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction { const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. 
Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n"; - std::string response = generate(ctx, prompt, true); + std::string response = generate(ctx, smpl, prompt, true); } + llama_sampling_free(smpl); llama_free(ctx); - llama_free_model(mdl); + llama_free_model(model); llama_backend_free(); return 0; diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 83b85d72b..1c7f53505 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -638,7 +638,7 @@ int main(int argc, char ** argv) { g_collector.save_imatrix(); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_free(ctx); llama_free_model(model); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 05700c1d5..371232421 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -2,7 +2,6 @@ #include "console.h" #include "llama.h" -#include "grammar-parser.h" #include #include @@ -34,6 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; +static llama_sampling ** g_smpl; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -93,7 +93,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx); + llama_print_timings(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -103,7 +103,6 @@ static void sigint_handler(int signo) { int main(int argc, char ** argv) { gpt_params params; - llama_sampling_params & sparams = params.sparams; g_params = ¶ms; if (!gpt_params_parse(argc, argv, params)) { @@ -111,6 +110,8 @@ int main(int argc, char ** argv) { return 1; } + auto & sparams = params.sparams; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("infill", "log")); LOG_TEE("Log start\n"); @@ -156,26 +157,21 @@ int main(int argc, char ** argv) { LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); } - LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); - LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); + print_build_info(); - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - LOG_TEE("%s: seed = %u\n", __func__, params.seed); - - std::mt19937 rng(params.seed); + LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); LOG("%s: llama backend init\n", __func__); llama_backend_init(); llama_numa_init(params.numa); - llama_model * model; - llama_context * ctx; + llama_model * model = nullptr; + llama_context * ctx = nullptr; + llama_sampling * smpl = nullptr; g_model = &model; g_ctx = &ctx; + g_smpl = &smpl; // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); @@ -305,7 +301,7 @@ int main(int argc, char ** argv) { LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); } } - LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); + LOG_TEE("sampling: \n%s\n", sparams.print_all().c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("\n\n"); @@ -349,7 +345,7 @@ int main(int argc, char ** argv) { std::vector embd; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + smpl = llama_sampling_init(model, sparams); while (n_remain != 0 || params.interactive) { // predict @@ -421,11 +417,11 @@ int 
main(int argc, char ** argv) { embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr); + const llama_token id = llama_sampling_sample(smpl, ctx, -1); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(smpl, id, true); - LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); + // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); embd.push_back(id); @@ -444,7 +440,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false); + llama_sampling_accept(smpl, embd_inp[n_consumed], false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -476,7 +472,7 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { // deal with eot token in infill mode - if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){ + if ((llama_sampling_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){ if (is_interacting && !params.interactive_first) { // print an eot token printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str()); @@ -542,7 +538,7 @@ int main(int argc, char ** argv) { is_interacting = false; } // deal with end of generation tokens in interactive mode - else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) { + else if (llama_token_is_eog(model, llama_sampling_last(smpl))) { LOG("found EOS token\n"); if (params.interactive) { @@ -615,7 +611,7 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - llama_sampling_reset(ctx_sampling); + llama_sampling_reset(smpl); } is_interacting = false; } @@ -638,13 +634,13 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); llama_free(ctx); llama_free_model(model); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_backend_free(); #ifndef LOG_DISABLE_LOGS diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 8edadef90..d4201d93b 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1574,7 +1574,7 @@ int main(int argc, char ** argv) { fflush(p_err->fout); } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_free(ctx); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 2aafe2316..c33f55f72 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -120,8 +120,8 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo LOGi("Using %d threads", n_threads); llama_context_params ctx_params = llama_context_default_params(); - ctx_params.seed = 1234; - ctx_params.n_ctx = 2048; + + ctx_params.n_ctx = 2048; ctx_params.n_threads = n_threads; ctx_params.n_threads_batch = n_threads; @@ -380,11 +380,13 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( JNIEnv * env, jobject, jlong context_pointer, + jlong sampling_pointer, jlong batch_pointer, jint 
n_len, jobject intvar_ncur ) { const auto context = reinterpret_cast(context_pointer); + const auto sampling = reinterpret_cast(sampling_pointer); const auto batch = reinterpret_cast(batch_pointer); const auto model = llama_get_model(context); @@ -392,20 +394,12 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); - auto n_vocab = llama_n_vocab(model); - auto logits = llama_get_logits_ith(context, batch->n_tokens - 1); + const auto * logits = llama_get_logits_ith(context, batch->n_tokens - 1); - std::vector candidates; - candidates.reserve(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + llama_sampling_set_logits(sampling, logits); // sample the most likely token - const auto new_token_id = llama_sample_token_greedy(context, &candidates_p); + const auto new_token_id = llama_sampling_sample_greedy(sampling, nullptr); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 48b7840ae..515170f67 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama actor LlamaContext { private var model: OpaquePointer private var context: OpaquePointer + private var sampling: OpaquePointer private var batch: llama_batch private var tokens_list: [llama_token] var is_done: Bool = false @@ -42,9 +43,11 @@ actor LlamaContext { self.tokens_list = [] self.batch = llama_batch_init(512, 0, 1) self.temporary_invalid_cchars = [] + self.sampling = llama_sampling_init(context, llama_sampling_default_params()) } deinit { + llama_sampling_free(sampling) llama_batch_free(batch) llama_free(context) llama_free_model(model) @@ -69,7 +72,6 @@ actor LlamaContext { print("Using \(n_threads) threads") var ctx_params = llama_context_default_params() - ctx_params.seed = 1234 ctx_params.n_ctx = 2048 ctx_params.n_threads = Int32(n_threads) ctx_params.n_threads_batch = Int32(n_threads) @@ -147,17 +149,9 @@ actor LlamaContext { let n_vocab = llama_n_vocab(model) let logits = llama_get_logits_ith(context, batch.n_tokens - 1) - var candidates = Array() - candidates.reserveCapacity(Int(n_vocab)) + llama_sampling_set_logits(sampling, logits); - for token_id in 0..sparams); - if (!ctx_sampling) { + struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams); + if (!smpl) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); } std::string response = ""; for (int i = 0; i < max_tgt_len; i++) { - const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); + const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); response += tmp; if (strcmp(tmp, "") == 0) break; if (strstr(tmp, "###")) break; // Yi-VL behavior @@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ fflush(stdout); } - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); printf("\n"); } @@ -310,7 
+310,7 @@ int main(int argc, char ** argv) { // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_print_timings(ctx_llava->ctx_llama); + llama_print_timings(ctx_llava->ctx_llama, nullptr); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); @@ -327,7 +327,7 @@ int main(int argc, char ** argv) { // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_print_timings(ctx_llava->ctx_llama); + llama_print_timings(ctx_llava->ctx_llama, nullptr); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index f500ea5b9..c041fe530 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e LOG_TEE("%s: image token past: %d\n", __func__, n_past); } -static const char * sample(struct llama_sampling_context * ctx_sampling, +static const char * sample(struct llama_sampling * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); - llama_sampling_accept(ctx_sampling, ctx_llama, id, true); + const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1); + llama_sampling_accept(smpl, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { ret = ""; @@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri return ctx_llava; } -static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ +static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ std::string user_prompt = prompt; int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); if (!is_first) { @@ -238,13 +238,13 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla LOG_TEE("\n"); - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); - return ctx_sampling; + struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams); + return smpl; } -static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){ +static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling * smpl, int &n_past){ - const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); + const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); return tmp; } @@ -278,12 +278,12 @@ int main(int argc, char ** argv) { if (!params.prompt.empty()) { LOG_TEE("%s\n", params.prompt.c_str()); LOG_TEE(""); - auto ctx_sampling = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true); + auto smpl = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true); const int max_tgt_len = params.n_predict < 0 ? 
256 : params.n_predict; std::string response = ""; bool have_tmp = false; for (int i = 0; i < max_tgt_len; i++) { - auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); + auto tmp = llama_loop(ctx_llava, smpl, n_past); response += tmp; if (strcmp(tmp, "") == 0){ if(!have_tmp)continue; @@ -296,18 +296,18 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); }else { while (true) { LOG_TEE(""); std::string prompt; std::getline(std::cin, prompt); LOG_TEE(""); - auto ctx_sampling = llama_init(ctx_llava, ¶ms, prompt, n_past, true); + auto smpl = llama_init(ctx_llava, ¶ms, prompt, n_past, true); const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; std::string response = ""; for (int i = 0; i < max_tgt_len; i++) { - auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); + auto tmp = llama_loop(ctx_llava, smpl, n_past); response += tmp; if (strcmp(tmp, "") == 0) break; if (strstr(tmp, "###")) break; // Yi-VL behavior @@ -315,11 +315,11 @@ int main(int argc, char ** argv) { if (strstr(response.c_str(), "")) break; // minicpm-v fflush(stdout); } - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); } } printf("\n"); - llama_print_timings(ctx_llava->ctx_llama); + llama_print_timings(ctx_llava->ctx_llama, nullptr); ctx_llava->model = NULL; llava_free(ctx_llava); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 81cf1629c..2bd31d002 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -1,7 +1,6 @@ #include "common.h" #include "llama.h" -#include #include #include #include @@ -118,7 +117,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling * smpl = llama_sampling_init(model, params.sparams); // verification n-grams std::vector ngrams_cur(G); @@ -159,9 +158,9 @@ int main(int argc, char ** argv) { // sample first token { - id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0); + id = llama_sampling_sample(smpl, ctx, 0); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(smpl, id, true); { const std::string token_str = llama_token_to_piece(ctx, id); @@ -284,9 +283,9 @@ int main(int argc, char ** argv) { } // sample the next token - id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch); + id = llama_sampling_sample(smpl, ctx, i_batch); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(smpl, id, true); // print { @@ -361,7 +360,7 @@ int main(int argc, char ** argv) { if (v == 0) { // sample from the last level for (int i = 0; i < W; i++) { - tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i); + tokens_j[N - 2][i] = llama_sampling_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); } } else { for (int i = 0; i < W; i++) { @@ -468,10 +467,10 @@ int main(int argc, char ** argv) { LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_accept = %d\n", n_accept); - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); llama_kv_cache_view_free(&kvc_view); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_batch_free(batch); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index d53a9828c..da4d57a51 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -3,13 +3,11 
@@ #include "common.h" #include "ngram-cache.h" -#include #include #include #include #include #include -#include int main(int argc, char ** argv){ gpt_params params; @@ -106,7 +104,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling * smpl = llama_sampling_init(model, params.sparams); std::vector draft; @@ -130,9 +128,9 @@ int main(int argc, char ** argv){ int i_dft = 0; while (true) { // sample from the target model - llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft); + llama_token id = llama_sampling_sample(smpl, ctx, i_dft); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(smpl, id, true); const std::string token_str = llama_token_to_piece(ctx, id); @@ -241,9 +239,9 @@ int main(int argc, char ** argv){ LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_batch_free(batch_tgt); llama_free(ctx); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c55efbb66..296c1c687 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -33,6 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; +static llama_sampling ** g_smpl; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -105,7 +106,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx); + llama_print_timings(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v static std::string chat_add_and_format(struct llama_model * model, std::vector & chat_msgs, std::string role, std::string content) { llama_chat_msg new_msg{role, content}; - auto formatted = llama_chat_format_single( - model, g_params->chat_template, chat_msgs, new_msg, role == "user"); + auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user"); chat_msgs.push_back({role, content}); LOG("formatted: %s\n", formatted.c_str()); return formatted; @@ -137,7 +137,7 @@ int main(int argc, char ** argv) { return 1; } - llama_sampling_params & sparams = params.sparams; + auto & sparams = params.sparams; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("main", "log")); @@ -183,27 +183,23 @@ int main(int argc, char ** argv) { LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); } - LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); - LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); + print_build_info(); - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - LOG_TEE("%s: seed = %u\n", __func__, params.seed); - - std::mt19937 rng(params.seed); + LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); LOG("%s: llama backend init\n", __func__); llama_backend_init(); llama_numa_init(params.numa); - llama_model * model; - llama_context * ctx; - llama_context * ctx_guidance = NULL; + llama_model * model = nullptr; + llama_context * ctx = nullptr; + llama_sampling * smpl = nullptr; + std::vector chat_msgs; + g_model = &model; g_ctx = &ctx; + g_smpl 
= &smpl; // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); @@ -211,10 +207,6 @@ int main(int argc, char ** argv) { model = llama_init.model; ctx = llama_init.context; - if (sparams.cfg_scale > 1.f) { - struct llama_context_params lparams = llama_context_params_from_gpt_params(params); - ctx_guidance = llama_new_context_with_model(model, lparams); - } if (model == NULL) { LOG_TEE("%s: error: unable to load model\n", __func__); @@ -251,9 +243,6 @@ int main(int argc, char ** argv) { } llama_attach_threadpool(ctx, threadpool, threadpool_batch); - if (ctx_guidance) { - llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch); - } const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); @@ -337,24 +326,6 @@ int main(int argc, char ** argv) { } // Tokenize negative prompt - std::vector guidance_inp; - int guidance_offset = 0; - int original_prompt_len = 0; - if (ctx_guidance) { - LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt)); - - guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true); - LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str()); - - std::vector original_inp = ::llama_tokenize(ctx, params.prompt, true, true); - LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str()); - - original_prompt_len = original_inp.size(); - guidance_offset = (int)guidance_inp.size() - original_prompt_len; - LOG("original_prompt_len: %s", log_tostr(original_prompt_len)); - LOG("guidance_offset: %s", log_tostr(guidance_offset)); - } - if ((int) embd_inp.size() > n_ctx - 4) { LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); return 1; @@ -421,15 +392,6 @@ int main(int argc, char ** argv) { LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str()); } - if (ctx_guidance) { - LOG_TEE("\n"); - LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str()); - LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); - for (int i = 0; i < (int) guidance_inp.size(); i++) { - LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str()); - } - } - if (params.n_keep > add_bos) { LOG_TEE("%s: static prompt based on n_keep: '", __func__); for (int i = 0; i < params.n_keep; i++) { @@ -495,8 +457,8 @@ int main(int argc, char ** argv) { } } } - LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); - LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str()); + LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str()); + LOG_TEE("sampling order: \n%s\n", sparams.print_samplers().c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); // group-attention state @@ -543,7 +505,6 @@ int main(int argc, char ** argv) { int n_remain = params.n_predict; int n_consumed = 0; int n_session_consumed = 0; - int n_past_guidance = 0; std::vector input_tokens; g_input_tokens = &input_tokens; std::vector output_tokens; g_output_tokens = &output_tokens; @@ -555,7 +516,6 @@ int main(int argc, char ** argv) { display = params.display_prompt; std::vector embd; - std::vector embd_guidance; // tokenized antiprompts std::vector> antiprompt_ids; @@ -565,8 +525,8 @@ int main(int argc, char ** argv) { 
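With the guidance context and the separate params.seed gone, everything main.cpp needs for sampling now sits in params.sparams. A hedged sketch of configuring and inspecting it — the values are placeholders, but the members and helpers are the ones introduced by this patch:

gpt_sampling_params sparams;

sparams.seed       = 1234;    // fixed seed; defaults to LLAMA_DEFAULT_SEED
sparams.top_k      = 40;
sparams.top_p      = 0.95f;
sparams.temp       = 0.8f;
sparams.ignore_eos = false;   // moved here from gpt_params
sparams.grammar    = "";      // optional GBNF grammar string

// the sampler chain can be given as characters, e.g. the default order "kfypmt"
sparams.samplers = llama_sampling_types_from_chars("kfypmt");

LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str());
LOG_TEE("sampling order: \n%s\n",  sparams.print_samplers().c_str());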
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); - if (!ctx_sampling) { + smpl = llama_sampling_init(model, sparams); + if (!smpl) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); } @@ -612,7 +572,7 @@ int main(int argc, char ** argv) { // if we run out of context: // - take the n_keep first tokens from the original prompt (via n_past) // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches - if (n_past + (int) embd.size() + std::max(0, guidance_offset) >= n_ctx) { + if (n_past + (int) embd.size() >= n_ctx) { if (params.n_predict == -2) { LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); break; @@ -629,11 +589,7 @@ int main(int argc, char ** argv) { n_past -= n_discard; - if (ctx_guidance) { - n_past_guidance -= n_discard; - } - - LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + LOG("after swap: n_past = %d\n", n_past); LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); @@ -686,46 +642,6 @@ int main(int argc, char ** argv) { } } - // evaluate tokens in batches - // embd is typically prepared beforehand to fit within a batch, but not always - if (ctx_guidance) { - int input_size = 0; - llama_token * input_buf = NULL; - - if (n_past_guidance < (int) guidance_inp.size()) { - // Guidance context should have the same data with these modifications: - // - // * Replace the initial prompt - // * Shift everything by guidance_offset - embd_guidance = guidance_inp; - if (embd.begin() + original_prompt_len < embd.end()) { - embd_guidance.insert( - embd_guidance.end(), - embd.begin() + original_prompt_len, - embd.end() - ); - } - - input_buf = embd_guidance.data(); - input_size = embd_guidance.size(); - - LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str()); - } else { - input_buf = embd.data(); - input_size = embd.size(); - } - - for (int i = 0; i < input_size; i += params.n_batch) { - int n_eval = std::min(input_size - i, params.n_batch); - if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) { - LOG_TEE("%s : failed to eval\n", __func__); - return 1; - } - - n_past_guidance += n_eval; - } - } - for (int i = 0; i < (int) embd.size(); i += params.n_batch) { int n_eval = (int) embd.size() - i; if (n_eval > params.n_batch) { @@ -755,7 +671,6 @@ int main(int argc, char ** argv) { } embd.clear(); - embd_guidance.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { // optionally save the session on first sample (for faster prompt loading next time) @@ -766,11 +681,11 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); + const llama_token id = llama_sampling_sample(smpl, ctx, -1); - llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true); + llama_sampling_accept(smpl, id, /* apply_grammar= */ true); - LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); + // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); embd.push_back(id); @@ -789,7 +704,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(ctx_sampling, ctx, 
embd_inp[n_consumed], /* apply_grammar= */ false); + llama_sampling_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -832,7 +747,7 @@ int main(int argc, char ** argv) { // check for reverse prompt in the last n_prev tokens if (!params.antiprompt.empty()) { const int n_prev = 32; - const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev); + const std::string last_output = llama_sampling_prev_str(smpl, ctx, n_prev); is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. @@ -854,7 +769,7 @@ int main(int argc, char ** argv) { } // check for reverse prompt using special tokens - llama_token last_token = llama_sampling_last(ctx_sampling); + llama_token last_token = llama_sampling_last(smpl); for (std::vector ids : antiprompt_ids) { if (ids.size() == 1 && last_token == ids[0]) { if (params.interactive) { @@ -871,7 +786,7 @@ int main(int argc, char ** argv) { } // deal with end of generation tokens in interactive mode - if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) { + if (llama_token_is_eog(model, llama_sampling_last(smpl))) { LOG("found an EOG token\n"); if (params.interactive) { @@ -892,7 +807,7 @@ int main(int argc, char ** argv) { // if current token is not EOG, we add it to current assistant message if (params.conversation) { - auto id = llama_sampling_last(ctx_sampling); + auto id = llama_sampling_last(smpl); assistant_ss << llama_token_to_piece(ctx, id, false); } @@ -988,7 +903,7 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - llama_sampling_reset(ctx_sampling); + llama_sampling_reset(smpl); } is_interacting = false; } @@ -1013,14 +928,13 @@ int main(int argc, char ** argv) { llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); - if (ctx_guidance) { llama_free(ctx_guidance); } llama_free(ctx); llama_free_model(model); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_backend_free(); ggml_threadpool_free(threadpool); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 621a1c959..7ce982a92 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -50,8 +50,8 @@ static std::vector k_prompts = { struct client { ~client() { - if (ctx_sampling) { - llama_sampling_free(ctx_sampling); + if (smpl) { + llama_sampling_free(smpl); } } @@ -72,7 +72,7 @@ struct client { std::string prompt; std::string response; - struct llama_sampling_context * ctx_sampling = nullptr; + struct llama_sampling * smpl = nullptr; }; static void print_date_time() { @@ -161,7 +161,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.ctx_sampling = llama_sampling_init(params.sparams); + client.smpl = llama_sampling_init(model, params.sparams); } std::vector tokens_system; @@ -253,7 +253,7 @@ int main(int argc, char ** argv) { client.prompt = client.input + "\nAssistant:"; client.response = ""; - llama_sampling_reset(client.ctx_sampling); + llama_sampling_reset(client.smpl); // do not prepend BOS because we have a system prompt! 
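The parallel example here keeps one sampler per client rather than any shared, context-owned state. A compressed sketch of that ownership pattern follows; the struct and function names are illustrative, the llama_sampling_* calls are the ones used in parallel.cpp.

// illustrative slot type; parallel.cpp's `client` follows the same shape
struct slot {
    int32_t i_batch = -1;
    llama_sampling * smpl = nullptr;   // created with llama_sampling_init(model, params.sparams)

    ~slot() { if (smpl) { llama_sampling_free(smpl); } }
};

// one sampling step for a slot whose logits sit at batch offset `i`
static llama_token slot_sample(slot & s, llama_context * ctx, int32_t i) {
    const llama_token id = llama_sampling_sample(s.smpl, ctx, s.i_batch - i);
    llama_sampling_accept(s.smpl, id, /* apply_grammar= */ true);
    return id;
}

// when a slot is reused for a new request, its penalty history and grammar state are cleared:
//     llama_sampling_reset(s.smpl);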
std::vector tokens_prompt; @@ -341,9 +341,9 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i); + const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i); - llama_sampling_accept(client.ctx_sampling, ctx, id, true); + llama_sampling_accept(client.smpl, id, true); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -371,7 +371,7 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); @@ -413,7 +413,8 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); - llama_print_timings(ctx); + // TODO: print sampling/grammar timings for all clients + llama_print_timings(ctx, nullptr); llama_batch_free(batch); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index d03215cd1..0992ccc3c 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -26,8 +26,6 @@ int main(int argc, char ** argv) { return 1; } - srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed); - int n_junk = params.n_junk; int n_keep = params.n_keep; int n_grp = params.grp_attn_n; @@ -80,12 +78,13 @@ int main(int argc, char ** argv) { GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); llama_context * ctx = llama_new_context_with_model(model, ctx_params); - if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); return 1; } + llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + // tokenize the prompt std::vector tokens_list; tokens_list = ::llama_tokenize(ctx, params.prompt, true); @@ -217,20 +216,12 @@ int main(int argc, char ** argv) { while (n_cur <= n_len) { // sample the next token { - auto n_vocab = llama_n_vocab(model); - auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - std::vector candidates; - candidates.reserve(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + llama_sampling_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); // is it an end of generation? 
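A note on the recurring mechanical change in these hunks: llama_print_timings now takes the sampler as a second argument, and examples free the sampler alongside the context. A sketch of the common teardown order; nullptr is what examples without their own sampler pass.

llama_print_timings(ctx, smpl);   // or: llama_print_timings(ctx, nullptr);

llama_batch_free(batch);
llama_sampling_free(smpl);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();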
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { @@ -267,12 +258,13 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); fprintf(stderr, "\n"); llama_batch_free(batch); + llama_sampling_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 484dd5891..987236ab6 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -2007,13 +2007,7 @@ int main(int argc, char ** argv) { print_build_info(); - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - fprintf(stderr, "%s: seed = %u\n", __func__, params.seed); - - std::mt19937 rng(params.seed); + LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); llama_backend_init(); llama_numa_init(params.numa); @@ -2054,7 +2048,7 @@ int main(int argc, char ** argv) { results = perplexity(ctx, params, n_ctx); } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); write_logfile(ctx, params, model, results); llama_free(ctx); diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 68cf8d359..498cbbe3c 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -1,7 +1,7 @@ -#define LLAMA_API_INTERNAL #include "common.h" #include "ggml.h" #include "llama.h" +#include "llama-impl.h" #include #include @@ -319,8 +319,7 @@ int main(int argc, char ** argv) { } auto cparams = llama_context_default_params(); - cparams.n_ctx = 256; - cparams.seed = 1; + cparams.n_ctx = 256; ctx = llama_new_context_with_model(model, cparams); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index aab9d8105..6089344a0 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -294,8 +294,8 @@ int main(int argc, char ** argv) { } // clean up + llama_print_timings(ctx, nullptr); llama_batch_free(query_batch); - llama_print_timings(ctx); llama_free(ctx); llama_free_model(model); llama_backend_free(); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 3ea7c790d..02f7a93eb 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -3,12 +3,12 @@ #include #include -#include int main(int argc, char ** argv) { gpt_params params; params.prompt = "The quick brown fox"; + params.sparams.seed = 1234; if (!gpt_params_parse(argc, argv, params)) { gpt_params_print_usage(argc, argv, params); @@ -38,6 +38,11 @@ int main(int argc, char ** argv) { return 1; } + llama_sampling_params sparams = llama_sampling_default_params(); + sparams.seed = params.sparams.seed; + + llama_sampling * smpl = llama_sampling_init(model, sparams); + // tokenize prompt auto tokens = llama_tokenize(ctx, params.prompt, true); @@ -64,16 +69,11 @@ int main(int argc, char ** argv) { printf("\nfirst run: %s", params.prompt.c_str()); for (auto i = 0; i < params.n_predict; i++) { - auto * logits = llama_get_logits(ctx); - auto n_vocab = llama_n_vocab(model); + const auto * logits = llama_get_logits(ctx); - std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - 
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx, &candidates_p); + llama_sampling_set_logits(smpl, logits); + + auto next_token = llama_sampling_sample_dist(smpl, nullptr); auto next_token_str = llama_token_to_piece(ctx, next_token); printf("%s", next_token_str.c_str()); @@ -96,6 +96,8 @@ int main(int argc, char ** argv) { // make new context auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + llama_sampling * smpl2 = llama_sampling_init(model, sparams); + printf("\nsecond run: %s", params.prompt.c_str()); // load state (rng, logits, embedding and kv_cache) from file @@ -124,15 +126,11 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { - auto * logits = llama_get_logits(ctx2); - auto n_vocab = llama_n_vocab(model); - std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx2, &candidates_p); + const auto * logits = llama_get_logits(ctx2); + + llama_sampling_set_logits(smpl2, logits); + + auto next_token = llama_sampling_sample_dist(smpl2, nullptr); auto next_token_str = llama_token_to_piece(ctx2, next_token); printf("%s", next_token_str.c_str()); @@ -157,7 +155,9 @@ int main(int argc, char ** argv) { } // make new context - auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + + llama_sampling * smpl3 = llama_sampling_init(model, sparams); printf("\nsingle seq run: %s", params.prompt.c_str()); @@ -215,15 +215,11 @@ int main(int argc, char ** argv) { // third run with seq 1 instead of 0 for (auto i = 0; i < params.n_predict; i++) { - auto * logits = llama_get_logits(ctx3); - auto n_vocab = llama_n_vocab(model); - std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx3, &candidates_p); + const auto * logits = llama_get_logits(ctx3); + + llama_sampling_set_logits(smpl3, logits); + + auto next_token = llama_sampling_sample_dist(smpl3, nullptr); auto next_token_str = llama_token_to_piece(ctx3, next_token); printf("%s", next_token_str.c_str()); @@ -240,6 +236,10 @@ int main(int argc, char ** argv) { printf("\n"); + llama_sampling_free(smpl); + llama_sampling_free(smpl2); + llama_sampling_free(smpl3); + llama_free(ctx3); llama_free_model(model); diff --git a/examples/server/README.md b/examples/server/README.md index 805e05b4a..37024dea0 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -470,8 +470,6 @@ node index.js `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled. - `penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens. Default: `null`, which is to use the original `prompt`. 
- `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0. `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0` @@ -724,7 +722,6 @@ Example: "stopping_word": "" }, "penalize_nl": true, - "penalty_prompt_tokens": [], "presence_penalty": 0.0, "prompt": "Say hello to llama.cpp", "repeat_last_n": 64, @@ -748,8 +745,7 @@ Example: "tfs_z": 1.0, "top_k": 40, "top_p": 0.949999988079071, - "typical_p": 1.0, - "use_penalty_prompt_tokens": false + "typical_p": 1.0 } ] ``` diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 109dbc023..f292c2cd6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3,7 +3,6 @@ #include "common.h" #include "json-schema-to-grammar.h" #include "llama.h" -#include "grammar-parser.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -173,11 +172,13 @@ struct server_slot { std::string stopping_word; // sampling - llama_token sampled; - struct llama_sampling_params sparams; - llama_sampling_context * ctx_sampling = nullptr; json json_schema; + struct gpt_sampling_params sparams; + + llama_token sampled; + llama_sampling * smpl = nullptr; + int32_t ga_i = 0; // group-attention state int32_t ga_n = 1; // group-attention factor int32_t ga_w = 512; // group-attention width @@ -636,8 +637,8 @@ struct server_context { // Clear any sampling context for (server_slot & slot : slots) { - if (slot.ctx_sampling != nullptr) { - llama_sampling_free(slot.ctx_sampling); + if (slot.smpl != nullptr) { + llama_sampling_free(slot.smpl); } } @@ -864,8 +865,8 @@ struct server_context { bool launch_slot_with_task(server_slot & slot, const server_task & task) { slot_params default_params; // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - llama_sampling_params default_sparams = params.sparams; - auto & data = task.data; + auto default_sparams = params.sparams; + const auto & data = task.data; if (data.count("__oaicompat") != 0) { slot.oaicompat = true; @@ -882,7 +883,7 @@ struct server_context { slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); + slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p); slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); @@ -904,7 +905,8 @@ struct server_context { if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST); return false; - } else if (data.contains("json_schema") && !data.contains("grammar")) { + } + if (data.contains("json_schema") && !data.contains("grammar")) { try { auto schema = json_value(data, "json_schema", json::object()); slot.sparams.grammar = json_schema_to_grammar(schema); @@ -954,56 +956,11 @@ struct server_context { } } - // penalize user-provided tokens - { - 
slot.sparams.penalty_prompt_tokens.clear(); - slot.sparams.use_penalty_prompt_tokens = false; - - const auto & penalty_prompt = data.find("penalty_prompt"); - - if (penalty_prompt != data.end()) { - if (penalty_prompt->is_string()) { - const auto penalty_prompt_string = penalty_prompt->get(); - slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false); - - if (slot.params.n_predict > 0) { - slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict); - } - slot.sparams.use_penalty_prompt_tokens = true; - - LOG_VERBOSE("penalty_prompt_tokens", { - {"id_slot", slot.id}, - {"tokens", slot.sparams.penalty_prompt_tokens}, - }); - } - else if (penalty_prompt->is_array()) { - const auto n_tokens = penalty_prompt->size(); - slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict)); - - const int n_vocab = llama_n_vocab(model); - for (const auto & penalty_token : *penalty_prompt) { - if (penalty_token.is_number_integer()) { - const auto tok = penalty_token.get(); - if (tok >= 0 && tok < n_vocab) { - slot.sparams.penalty_prompt_tokens.push_back(tok); - } - } - } - slot.sparams.use_penalty_prompt_tokens = true; - - LOG_VERBOSE("penalty_prompt_tokens", { - {"id_slot", slot.id}, - {"tokens", slot.sparams.penalty_prompt_tokens}, - }); - } - } - } - { slot.sparams.logit_bias.clear(); if (json_value(data, "ignore_eos", false) && has_eos_token) { - slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY}); } const auto & logit_bias = data.find("logit_bias"); @@ -1024,12 +981,12 @@ struct server_context { if (el[0].is_number_integer()) { llama_token tok = el[0].get(); if (tok >= 0 && tok < n_vocab) { - slot.sparams.logit_bias[tok] = bias; + slot.sparams.logit_bias.push_back({tok, bias}); } } else if (el[0].is_string()) { auto toks = llama_tokenize(model, el[0].get(), false); for (auto tok : toks) { - slot.sparams.logit_bias[tok] = bias; + slot.sparams.logit_bias.push_back({tok, bias}); } } } @@ -1051,26 +1008,27 @@ struct server_context { } { - const auto & samplers_sequence = data.find("samplers"); - if (samplers_sequence != data.end() && samplers_sequence->is_array()) { + const auto & samplers = data.find("samplers"); + if (samplers != data.end() && samplers->is_array()) { std::vector sampler_names; - for (const auto & sampler_name : *samplers_sequence) { + for (const auto & sampler_name : *samplers) { if (sampler_name.is_string()) { sampler_names.emplace_back(sampler_name); } } - slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false); + slot.sparams.samplers = llama_sampling_types_from_names(sampler_names, false); } else { - slot.sparams.samplers_sequence = default_sparams.samplers_sequence; + slot.sparams.samplers = default_sparams.samplers; } } { - if (slot.ctx_sampling != nullptr) { - llama_sampling_free(slot.ctx_sampling); + if (slot.smpl != nullptr) { + llama_sampling_free(slot.smpl); } - slot.ctx_sampling = llama_sampling_init(slot.sparams); - if (slot.ctx_sampling == nullptr) { + + slot.smpl = llama_sampling_init(model, slot.sparams); + if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); return false; @@ -1159,11 +1117,6 @@ struct server_context { slot.generated_text += token_str; slot.has_next_token = true; - if (slot.ctx_sampling->params.use_penalty_prompt_tokens 
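Note that logit biases are now accumulated as a flat list of llama_logit_bias entries instead of a std::map, which lines up with the llama_sampling_set_logit_bias entry point declared further down in include/llama.h. A small sketch of the same idea outside the server (the helper name and the +1.5f boost are arbitrary; the EOS ban mirrors the ignore_eos handling above):

#include "llama.h"

#include <cmath>
#include <vector>

// illustrative helper: ban EOS and mildly boost one chosen token
static void apply_bias(struct llama_sampling * smpl, const struct llama_model * model, llama_token boosted) {
    std::vector<llama_logit_bias> bias;

    bias.push_back({ llama_token_eos(model), -INFINITY }); // never sample EOS
    bias.push_back({ boosted,                +1.5f      }); // small positive bias

    llama_sampling_set_logit_bias(smpl, (int32_t) bias.size(), bias.data());
}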
&& result.tok != -1) { - // we can change penalty_prompt_tokens because it is always created from scratch each request - slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); - } - // check if there is incomplete UTF-8 character at the end bool incomplete = false; for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) { @@ -1281,13 +1234,10 @@ struct server_context { } json get_formated_generation(const server_slot & slot) const { - const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); - const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); - - std::vector samplers_sequence; - samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); - for (const auto & sampler_type : slot.sparams.samplers_sequence) { - samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type)); + std::vector samplers; + samplers.reserve(slot.sparams.samplers.size()); + for (const auto & sampler : slot.sparams.samplers) { + samplers.emplace_back(llama_sampling_type_to_str(sampler)); } return json { @@ -1302,13 +1252,11 @@ struct server_context { {"top_p", slot.sparams.top_p}, {"min_p", slot.sparams.min_p}, {"tfs_z", slot.sparams.tfs_z}, - {"typical_p", slot.sparams.typical_p}, + {"typical_p", slot.sparams.typ_p}, {"repeat_last_n", slot.sparams.penalty_last_n}, {"repeat_penalty", slot.sparams.penalty_repeat}, {"presence_penalty", slot.sparams.penalty_present}, {"frequency_penalty", slot.sparams.penalty_freq}, - {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, - {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, {"mirostat", slot.sparams.mirostat}, {"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_eta", slot.sparams.mirostat_eta}, @@ -1317,13 +1265,13 @@ struct server_context { {"max_tokens", slot.params.n_predict}, // User configured n_predict {"n_keep", slot.params.n_keep}, {"n_discard", slot.params.n_discard}, - {"ignore_eos", ignore_eos}, + {"ignore_eos", slot.sparams.ignore_eos}, {"stream", slot.params.stream}, - {"logit_bias", slot.sparams.logit_bias}, + //{"logit_bias", slot.sparams.logit_bias}, {"n_probs", slot.sparams.n_probs}, {"min_keep", slot.sparams.min_keep}, {"grammar", slot.sparams.grammar}, - {"samplers", samplers_sequence} + {"samplers", samplers}, }; } @@ -2139,7 +2087,7 @@ struct server_context { GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); } - llama_sampling_reset(slot.ctx_sampling); + llama_sampling_reset(slot.smpl); if (!slot.params.cache_prompt) { slot.n_past_se = 0; @@ -2152,7 +2100,7 @@ struct server_context { // push the prompt into the sampling context (do not apply grammar) for (int i = 0; i < slot.n_past; ++i) { - llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); + llama_sampling_accept(slot.smpl, slot.cache_tokens[i], false); } } } @@ -2205,7 +2153,7 @@ struct server_context { slot.n_past_se = 0; slot.ga_i = 0; // TODO: is the system prompt ever in the sampling context? 
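The prompt-replay above is what keeps repetition penalties consistent when a slot reuses its KV cache: the sampler is reset and then fed every cached token with apply_grammar=false, since those tokens were not generated under the current grammar. Roughly:

#include "llama.h"

#include <vector>

// rebuild the sampler's "prev" history from tokens already in the KV cache
static void replay_cached_tokens(struct llama_sampling * smpl,
                                 const std::vector<llama_token> & cache_tokens,
                                 int n_past) {
    llama_sampling_reset(smpl);

    for (int i = 0; i < n_past; ++i) {
        // do not advance the grammar, only the penalty/prev window
        llama_sampling_accept(smpl, cache_tokens[i], /* apply_grammar= */ false);
    }
}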
- llama_sampling_reset(slot.ctx_sampling); + llama_sampling_reset(slot.smpl); } // remove the non-common part from the cache @@ -2382,9 +2330,9 @@ struct server_context { } completion_token_output result; - const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i); + const llama_token id = llama_sampling_sample(slot.smpl, ctx, slot.i_batch - i); - llama_sampling_accept(slot.ctx_sampling, ctx, id, true); + llama_sampling_accept(slot.smpl, id, true); slot.n_decoded += 1; if (slot.n_decoded == 1) { @@ -2393,34 +2341,17 @@ struct server_context { metrics.on_prompt_eval(slot); } - llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false }; result.tok = id; - const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs); - if (n_probs > 0) { - const size_t n_valid = slot.ctx_sampling->n_valid; + const auto * cur_p = llama_sampling_get_candidates(slot.smpl); - // Make sure at least n_probs top tokens are at the front of the vector: - if (slot.sparams.temp == 0.0f && n_probs > n_valid) { - llama_sample_top_k(ctx, &cur_p, n_probs, 0); - } - - if (slot.sparams.temp == 0.0f) { - // With greedy sampling the probabilities have possibly not been calculated. - for (size_t i = 0; i < n_probs; ++i) { - result.probs.push_back({ - cur_p.data[i].id, - i == 0 ? 1.0f : 0.0f - }); - } - } else { - for (size_t i = 0; i < n_probs; ++i) { - result.probs.push_back({ - cur_p.data[i].id, - i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. - }); - } - } + // TODO: this logic might have been broken during https://github.com/ggerganov/llama.cpp/pull/8643 + // fix if necessary + for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) { + result.probs.push_back({ + cur_p->data[i].id, + i >= cur_p->size ? 0.0f : cur_p->data[i].p, + }); } if (!process_token(result, slot)) { diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 69a92cf7d..674158b85 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -55,6 +55,8 @@ int main(int argc, char ** argv) { return 1; } + llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + // tokenize the prompt std::vector tokens_list; @@ -110,20 +112,12 @@ int main(int argc, char ** argv) { while (n_cur <= n_predict) { // sample the next token { - auto n_vocab = llama_n_vocab(model); - auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - std::vector candidates; - candidates.reserve(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + llama_sampling_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); // is it an end of generation? 
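As the TODO in the server hunk above admits, the n_probs readout may still need fixing; the intent is to sample first and then walk the sampler's candidate array. A hedged sketch that clamps the loop to the candidate count instead of indexing past it (probabilities are only meaningful once a softmax has run, e.g. as part of sampling):

#include "llama.h"

#include <algorithm>
#include <cstdio>

// print up to n_probs of the current candidates after a token was sampled
static void print_top_probs(struct llama_sampling * smpl, int32_t n_probs) {
    const llama_token_data_array * cur_p = llama_sampling_get_candidates(smpl);

    const size_t n = std::min((size_t) n_probs, cur_p->size);
    for (size_t i = 0; i < n; ++i) {
        printf("  cand %zu: token %6d  p = %.4f\n", i, (int) cur_p->data[i].id, cur_p->data[i].p);
    }
}

The simple.cpp hunk above uses the same set_logits plus sample_greedy pair as passkey; its end-of-generation check continues below.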
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -160,12 +154,13 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); fprintf(stderr, "\n"); llama_batch_free(batch); + llama_sampling_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 1616edecb..e95073366 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -21,7 +21,7 @@ struct seq_draft { std::vector<llama_token> tokens; std::vector<std::vector<llama_token_data>> dists; - struct llama_sampling_context * ctx_sampling; + struct llama_sampling * smpl; }; int main(int argc, char ** argv) { @@ -37,16 +37,16 @@ int main(int argc, char ** argv) { return 1; } + // for probabilities to be computed even with temp = 0 + params.sparams.n_probs = 16; + // max number of parallel drafting sequences (i.e. tree branches) const int n_seq_dft = params.n_parallel; // probability threshold for splitting a draft branch (only for n_seq_dft > 1) const float p_split = params.p_split; - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - std::default_random_engine rng(params.seed); + std::default_random_engine rng(params.sparams.seed); std::uniform_real_distribution<> u_dist; #ifndef LOG_DISABLE_LOGS @@ -179,19 +179,15 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; - // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + // target model sampling context (reuse the llama_context's sampling instance) + struct llama_sampling * smpl = llama_sampling_init(model_tgt, params.sparams); // draft sequence data std::vector<seq_draft> drafts(n_seq_dft); - params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar - if (params.sparams.temp == 0) { - params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model - } - for (int s = 0; s < n_seq_dft; ++s) { - drafts[s].ctx_sampling = llama_sampling_init(params.sparams); + // allocate llama_sampling for each draft sequence + drafts[s].smpl = llama_sampling_init(model_dft, params.sparams); } llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); @@ -234,9 +230,15 @@ int main(int argc, char ** argv) { if (params.sparams.temp > 0) { // stochastic verification - llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL); - llama_sample_softmax(ctx_tgt, &dist_tgt); - float p_tgt = 0, p_dft = 0; + llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft])); + + auto & dist_tgt = *llama_sampling_get_candidates(smpl); + + llama_sampling_grammar(smpl, &dist_tgt); + llama_sampling_softmax(smpl, &dist_tgt); + + float p_tgt = 0.0f; + float p_dft = 0.0f; // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); @@ -278,7 +280,7 @@ int main(int argc, char ** argv) { accept = true; token_id = drafts[s].tokens[i_dft]; token_str = llama_token_to_piece(ctx_tgt, token_id); - llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + llama_sampling_accept(smpl, token_id, true); LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); break; @@ -332,8 +334,8 @@ int
main(int argc, char ** argv) { // all drafted tokens were rejected // sample from the target model LOG("all drafted tokens were rejected, sampling from residual distribution\n"); - token_id = llama_sample_token(ctx_tgt, &dist_tgt); - llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + token_id = llama_sampling_sample_dist(smpl, &dist_tgt); + llama_sampling_accept(smpl, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } @@ -342,11 +344,11 @@ int main(int argc, char ** argv) { // sample from the target model LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + token_id = llama_sampling_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); - llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + llama_sampling_accept(smpl, token_id, true); - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); + //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str()); token_str = llama_token_to_piece(ctx_tgt, token_id); @@ -434,7 +436,7 @@ int main(int argc, char ** argv) { break; } - llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling); + llama_sampling_cp(smpl, drafts[0].smpl); int n_seq_cur = 1; int n_past_cur = n_past_dft; @@ -463,20 +465,20 @@ int main(int argc, char ** argv) { continue; } - llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft); + llama_sampling_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft); - const auto & cur_p = drafts[s].ctx_sampling->cur; + const auto * cur_p = llama_sampling_get_candidates(drafts[s].smpl); - for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) { + for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) { LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); + k, s, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); } std::vector sa(1, s); // attempt to split the branch if the probability is high enough for (int f = 1; f < 8; ++f) { - if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { + if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) { LOG("splitting seq %3d into %3d\n", s, n_seq_cur); llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); @@ -503,7 +505,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; - llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling); + llama_sampling_cp(drafts[s].smpl, drafts[n_seq_cur].smpl); sa.push_back(n_seq_cur); @@ -515,15 +517,15 @@ int main(int argc, char ** argv) { // add drafted token for each sequence for (int is = 0; is < (int) sa.size(); ++is) { - const llama_token id = cur_p[is].id; + const llama_token id = cur_p->data[is].id; const int s = sa[is]; - llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); + llama_sampling_accept(drafts[s].smpl, id, true); drafts[s].tokens.push_back(id); // save cur_p.data into drafts[s].dists - drafts[s].dists.push_back(cur_p); + drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size}); // add unique drafted tokens to the target batch drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); @@ -594,14 +596,15 @@ int main(int argc, char ** argv) { LOG_TEE("accept = 
%.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ndraft:\n"); - llama_print_timings(ctx_dft); + // TODO: print sampling/grammar timings for all drafts + llama_print_timings(ctx_dft, nullptr); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx_tgt); + llama_print_timings(ctx_tgt, smpl); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); for (int s = 0; s < n_seq_dft; ++s) { - llama_sampling_free(drafts[s].ctx_sampling); + llama_sampling_free(drafts[s].smpl); } llama_batch_free(batch_dft); diff --git a/include/llama.h b/include/llama.h index bfc37e88b..6feebba94 100644 --- a/include/llama.h +++ b/include/llama.h @@ -33,16 +33,21 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF +// TODO: use everywhere in the implementation +#define LLAMA_TOKEN_NULL -1 + #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 8 +#define LLAMA_SESSION_VERSION 9 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 2 +#define LLAMA_MAX_SAMPLERS 16 + #ifdef __cplusplus extern "C" { #endif @@ -53,8 +58,10 @@ extern "C" { // TODO: show sample usage // + // struct llama_vocab; // TODO: add in the future struct llama_model; struct llama_context; + struct llama_sampling; typedef int32_t llama_pos; typedef int32_t llama_token; @@ -199,6 +206,16 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; + enum llama_sampler_type { + LLAMA_SAMPLER_TYPE_NONE = 0, + LLAMA_SAMPLER_TYPE_TOP_K = 1, + LLAMA_SAMPLER_TYPE_TOP_P = 2, + LLAMA_SAMPLER_TYPE_MIN_P = 3, + LLAMA_SAMPLER_TYPE_TFS_Z = 4, + LLAMA_SAMPLER_TYPE_TYPICAL_P = 5, + LLAMA_SAMPLER_TYPE_TEMPERATURE = 6, + }; + typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -206,6 +223,7 @@ extern "C" { } llama_token_data; typedef struct llama_token_data_array { + // TODO: consider SoA llama_token_data * data; size_t size; bool sorted; @@ -300,7 +318,6 @@ extern "C" { // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations // https://github.com/ggerganov/llama.cpp/pull/7544 struct llama_context_params { - uint32_t seed; // RNG seed, -1 for random uint32_t n_ctx; // text context, 0 = from model uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size @@ -328,7 +345,8 @@ extern "C" { enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] - // Keep the booleans together to avoid misalignment during copy-by-value. + // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. 
+ // TODO: move at the end of the struct bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU @@ -356,53 +374,56 @@ extern "C" { void * kv_overrides; // pointer to vector containing overrides } llama_model_quantize_params; - // grammar types - struct llama_grammar; + typedef struct llama_logit_bias { + llama_token token; + float bias; + } llama_logit_bias; - // grammar element type - enum llama_gretype { - // end of rule definition - LLAMA_GRETYPE_END = 0, + // parameters for sampling the logits + typedef struct llama_sampling_params { + uint32_t seed; // the seed used to initialize llama_sampling_context + int32_t n_prev; // number of previous tokens to remember + int32_t n_probs; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k; // <= 0 to use vocab size + float top_p; // 1.0 = disabled + float min_p; // 0.0 = disabled + float tfs_z; // 1.0 = disabled + float typ_p; // typical_p, 1.0 = disabled + float temp; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range; // 0.0 = disabled + float dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat; // 1.0 = disabled + float penalty_freq; // 0.0 = disabled + float penalty_present; // 0.0 = disabled + int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau; // target entropy + float mirostat_eta; // learning rate - // start of alternate definition for rule - LLAMA_GRETYPE_ALT = 1, + // samplers + int32_t n_samplers; + enum llama_sampler_type samplers[LLAMA_MAX_SAMPLERS]; - // non-terminal element: reference to rule - LLAMA_GRETYPE_RULE_REF = 2, - - // terminal element: character (code point) - LLAMA_GRETYPE_CHAR = 3, - - // inverse char(s) ([^a], [^a-b] [^abc]) - LLAMA_GRETYPE_CHAR_NOT = 4, - - // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to - // be an inclusive range ([a-z]) - LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, - - // modifies a preceding LLAMA_GRETYPE_CHAR or - // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - LLAMA_GRETYPE_CHAR_ALT = 6, - - // any character (.) - LLAMA_GRETYPE_CHAR_ANY = 7, - }; - - typedef struct llama_grammar_element { - enum llama_gretype type; - uint32_t value; // Unicode code point or rule ID - } llama_grammar_element; + // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. 
+ bool penalize_nl; // consider newlines as a repeatable token + bool ignore_eos; // ignore the end-of-sequence token + } llama_sampling_params; // performance timing information struct llama_timings { double t_start_ms; double t_end_ms; double t_load_ms; - double t_sample_ms; + double t_sampling_ms; + double t_grammar_ms; + double t_accept_ms; double t_p_eval_ms; double t_eval_ms; - int32_t n_sample; + int32_t n_sampling; + int32_t n_grammar; + int32_t n_accept; int32_t n_p_eval; int32_t n_eval; }; @@ -417,8 +438,9 @@ extern "C" { struct llama_lora_adapter; // Helpers for getting default parameters - LLAMA_API struct llama_model_params llama_model_default_params(void); - LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_model_params llama_model_default_params(void); + LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_sampling_params llama_sampling_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); // Initialize the llama + ggml backend @@ -445,6 +467,7 @@ extern "C" { LLAMA_API void llama_free_model(struct llama_model * model); + // TODO: rename to llama_init_from_model LLAMA_API struct llama_context * llama_new_context_with_model( struct llama_model * model, struct llama_context_params params); @@ -460,23 +483,22 @@ extern "C" { LLAMA_API bool llama_supports_mlock (void); LLAMA_API bool llama_supports_gpu_offload(void); - LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); - LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); - - LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); - LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_n_layer (const struct llama_model * model); + LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); + LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + // Get the model's RoPE frequency scaling factor LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); @@ -704,7 +726,7 @@ extern "C" { // // Returns the *actual* size in bytes of the state - // (rng, logits, embedding and kv_cache) + // (logits, embedding and kv_cache) // Only use when saving the state, not when restoring it, otherwise the size may be too small. LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx), @@ -1006,159 +1028,131 @@ extern "C" { char * buf, int32_t length); - // - // Grammar - // - - /// Initialize a llama_grammar. - /// - /// @param rules The rule elements of the grammar to initialize. - /// @param n_rules The number of rules. 
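To make the new llama_sampling_params concrete, a typical caller starts from llama_sampling_default_params(), overrides a few fields, and creates one llama_sampling per context or per server slot. A minimal sketch (the specific values are arbitrary examples, not recommended defaults):

#include "llama.h"

static struct llama_sampling * make_sampler(const struct llama_model * model) {
    struct llama_sampling_params sparams = llama_sampling_default_params();

    sparams.seed  = 1234;   // seed for the sampler's RNG (dist/mirostat sampling)
    sparams.top_k = 40;
    sparams.top_p = 0.95f;
    sparams.temp  = 0.80f;

    // explicit sampler chain used by llama_sampling_sample()
    sparams.n_samplers  = 3;
    sparams.samplers[0] = LLAMA_SAMPLER_TYPE_TOP_K;
    sparams.samplers[1] = LLAMA_SAMPLER_TYPE_TOP_P;
    sparams.samplers[2] = LLAMA_SAMPLER_TYPE_TEMPERATURE;

    return llama_sampling_init(model, sparams);
}

llama_print_timings(ctx, smpl) then reports the sampling/grammar/accept timings next to the context timings, and llama_sampling_free(smpl) releases the sampler.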
- /// @param start_rule_index The index of the root rule (the starting point of the grammar). - /// @return The initialized llama_grammar or nullptr if initialization failed. - LLAMA_API struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); - - LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); - - LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); - - /// @details Apply constraints from grammar - LLAMA_API void llama_grammar_sample( - const struct llama_grammar * grammar, - const struct llama_context * ctx, - llama_token_data_array * candidates); - LLAMA_API DEPRECATED(void llama_sample_grammar( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar), - "use llama_grammar_sample instead"); - - /// @details Accepts the sampled token into the grammar - LLAMA_API void llama_grammar_accept_token( - struct llama_grammar * grammar, - struct llama_context * ctx, - llama_token token); - // // Sampling functions // - // Sets the current rng seed. - LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + // TODO: llama_model should become llama_vocab + LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params); + + LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); + + // Copies the internal state of the sampler (rng, prev, params, grammar, etc.) + LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl); + + // - clear prev token + // - reset grammar state + LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl); + + // Sampling parameter mutation + // TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable + LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); + LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); + + // Set the logits from which to sample. + // This call initializes the internal token candidates array. + // The internal candidates are implicitly used by the sampling API below when no candidates are provided. + LLAMA_API void llama_sampling_set_logits( + struct llama_sampling * smpl, + const float * logits); + + /// @details Returns the current candidate tokens. + LLAMA_API llama_token_data_array * llama_sampling_get_candidates( + struct llama_sampling * smpl); + + // The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object. + // Each function can accept an array of token candidates. If the candidates are not provided, the internal + // candidates are used. The internal candidates are initialized by llama_sampling_set_logits(). + + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. 
+ LLAMA_API void llama_sampling_softmax( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + LLAMA_API void llama_sampling_top_k( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + LLAMA_API void llama_sampling_top_p( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + LLAMA_API void llama_sampling_min_p( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. + LLAMA_API void llama_sampling_tail_free( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + LLAMA_API void llama_sampling_typical( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Apply temperature and entropy + LLAMA_API void llama_sampling_temp( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Apply constraints from grammar + LLAMA_API void llama_sampling_grammar( + struct llama_sampling * smpl, + llama_token_data_array * candidates); /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_repetition_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present); + LLAMA_API void llama_sampling_penalties( + struct llama_sampling * smpl, + llama_token_data_array * candidates); - /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 - /// @param logits Logits extracted from the original generation context. - /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. - /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. - LLAMA_API void llama_sample_apply_guidance( - struct llama_context * ctx, - float * logits, - float * logits_guidance, - float scale); - - /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. 
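Compared with the old llama_sample_* calls, the per-filter knobs (k, p, min_keep and so on) now come from the llama_sampling_params given at init time, so a hand-rolled pipeline only threads the candidate set through the filters declared above. A sketch, passing nullptr so every call operates on the sampler's internal candidates:

#include "llama.h"

// apply a few filters explicitly, then draw from the resulting distribution
static llama_token sample_by_hand(struct llama_sampling * smpl, const float * logits) {
    llama_sampling_set_logits(smpl, logits);

    llama_sampling_penalties(smpl, nullptr); // repeat/frequency/presence penalties
    llama_sampling_top_k    (smpl, nullptr);
    llama_sampling_top_p    (smpl, nullptr);
    llama_sampling_temp     (smpl, nullptr);

    return llama_sampling_sample_dist(smpl, nullptr);
}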
- LLAMA_API void llama_sample_softmax( - struct llama_context * ctx, - llama_token_data_array * candidates); - - /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_k( - struct llama_context * ctx, - llama_token_data_array * candidates, - int32_t k, - size_t min_keep); - - /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_p( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); - - /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - LLAMA_API void llama_sample_min_p( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); - - /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sample_tail_free( - struct llama_context * ctx, - llama_token_data_array * candidates, - float z, - size_t min_keep); - - /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sample_typical( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); - - /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. - LLAMA_API void llama_sample_entropy( - struct llama_context * ctx, - llama_token_data_array * candidates_p, - float min_temp, - float max_temp, - float exponent_val); - - LLAMA_API void llama_sample_temp( - struct llama_context * ctx, - llama_token_data_array * candidates, - float temp); - - /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. - /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat( - struct llama_context * ctx, - llama_token_data_array * candidates, - float tau, - float eta, - int32_t m, - float * mu); - - /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. 
- /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat_v2( - struct llama_context * ctx, - llama_token_data_array * candidates, - float tau, - float eta, - float * mu); + /// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + LLAMA_API llama_token llama_sampling_sample_mirostat( + struct llama_sampling * smpl, + llama_token_data_array * candidates); /// @details Selects the token with the highest probability. - /// Does not compute the token probabilities. Use llama_sample_softmax() instead. - LLAMA_API llama_token llama_sample_token_greedy( - struct llama_context * ctx, - llama_token_data_array * candidates); + /// Does not compute the token probabilities. Use llama_sampling_softmax() instead. + LLAMA_API llama_token llama_sampling_sample_greedy( + struct llama_sampling * smpl, + llama_token_data_array * candidates); - /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx. - LLAMA_API llama_token llama_sample_token( - struct llama_context * ctx, - llama_token_data_array * candidates); + /// @details Randomly selects a token from the candidates based on their probability distribution. + LLAMA_API llama_token llama_sampling_sample_dist( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers"). + LLAMA_API llama_token llama_sampling_sample( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Accepts the sampled token into the sampling context. + /// - adds it to "prev" tokens + /// - updates the grammar state (if apply_grammar is true) + LLAMA_API void llama_sampling_accept( + struct llama_sampling * smpl, + llama_token token, + bool apply_grammar); + + /// @details Get the number of accepted tokens so far (max of n_prev) + LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl); + + /// @details Get the ith accepted token + /// @param ith [0, n_prev), ith == 0 is the last accepted token. 
+ /// returns LLAMA_TOKEN_NULL if ith is out of bounds + LLAMA_API llama_token llama_sampling_prev( + const struct llama_sampling * smpl, + int32_t ith); + + /// @details Get the last accepted token + /// Same as llama_sampling_prev(smpl, 0) + /// returns LLAMA_TOKEN_NULL if there are no accepted tokens + LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); // // Model split @@ -1177,8 +1171,8 @@ extern "C" { // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); - LLAMA_API void llama_print_timings(struct llama_context * ctx); - LLAMA_API void llama_reset_timings(struct llama_context * ctx); + LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl); + LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl); // Print system information LLAMA_API const char * llama_print_system_info(void); @@ -1193,59 +1187,4 @@ extern "C" { } #endif -// Internal API to be implemented by llama.cpp and used by tests/benchmarks only -#ifdef LLAMA_API_INTERNAL - -#include -#include -#include - -struct ggml_tensor; - -const std::vector> & llama_internal_get_tensor_map( - struct llama_context * ctx -); - -struct llama_partial_utf8 { - uint32_t value; // bit value so far (unshifted) - int n_remain; // num bytes remaining; -1 indicates invalid sequence -}; - -struct llama_grammar_candidate { - size_t index; - const uint32_t * code_points; - llama_partial_utf8 partial_utf8; -}; - -using llama_grammar_rule = std::vector< llama_grammar_element>; -using llama_grammar_stack = std::vector; - -using llama_grammar_rules = std::vector; -using llama_grammar_stacks = std::vector; -using llama_grammar_candidates = std::vector; - -const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); - llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); - -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks); - -std::vector llama_grammar_reject_candidates_for_stack( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - const llama_grammar_candidates & candidates); - -std::pair, llama_partial_utf8> decode_utf8( - const std::string & src, - llama_partial_utf8 partial_start); - -// Randomly selects a token from the candidates based on their probabilities using given std::mt19937. -// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences. -llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng); - -#endif // LLAMA_API_INTERNAL - #endif // LLAMA_H diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index b123d7331..8cd98bae4 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -4,10 +4,29 @@ #include "llama-sampling.h" #include +#include -// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as -// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. 
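Before the grammar parser internals below, it is worth summarizing how the pieces of the new header fit together in a generation loop: decode, set the logits, sample with the configured chain, then accept the token so the prev history and grammar state stay in sync. A hedged end-to-end sketch (generate_step and i_last are illustrative names):

#include "llama.h"

// one generation step: assumes llama_decode() has just produced logits for
// the batch position at index i_last
static llama_token generate_step(struct llama_context * ctx,
                                 struct llama_sampling * smpl,
                                 int32_t i_last) {
    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, i_last));

    // runs the samplers configured in llama_sampling_params.samplers
    const llama_token id = llama_sampling_sample(smpl, nullptr);

    // record the token: updates the prev history and, here, also the grammar state
    llama_sampling_accept(smpl, id, /* apply_grammar= */ true);

    return id;
}

// afterwards, llama_sampling_last(smpl) == id, and llama_sampling_prev(smpl, 1)
// is the token accepted before it (or LLAMA_TOKEN_NULL when out of range).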
-std::pair, llama_partial_utf8> decode_utf8( +// +// helpers +// + +// NOTE: assumes valid utf8 (but checks for overrun) +static std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); +} + +static std::pair, llama_partial_utf8> decode_utf8( const std::string & src, llama_partial_utf8 partial_start) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; @@ -40,7 +59,7 @@ std::pair, llama_partial_utf8> decode_utf8( while (*pos != 0) { uint8_t first_byte = static_cast(*pos); uint8_t highbits = first_byte >> 4; - n_remain = lookup[highbits] - 1; + n_remain = lookup[highbits] - 1; if (n_remain < 0) { // invalid sequence, abort @@ -50,7 +69,7 @@ std::pair, llama_partial_utf8> decode_utf8( } uint8_t mask = (1 << (7 - n_remain)) - 1; - value = first_byte & mask; + value = first_byte & mask; ++pos; while (*pos != 0 && n_remain > 0) { @@ -67,12 +86,510 @@ std::pair, llama_partial_utf8> decode_utf8( return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); } -const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { - return grammar->rules; +static bool is_digit_char(char c) { + return '0' <= c && c <= '9'; } -llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { - return grammar->stacks; +static bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); +} + +static std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); +} + +static const char * parse_space(const char * src, bool newline_ok) { + const char * pos = src; + while (*pos == ' ' || *pos == '\t' || *pos == '#' || + (newline_ok && (*pos == '\r' || *pos == '\n'))) { + if (*pos == '#') { + while (*pos && *pos != '\r' && *pos != '\n') { + pos++; + } + } else { + pos++; + } + } + return pos; +} + +static const char * parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting name at ") + src); + } + return pos; +} + +static const char * parse_int(const char * src) { + const char * pos = src; + while (is_digit_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting integer at ") + src); + } + return pos; +} + +static std::pair parse_char(const char * src) { + if (*src == '\\') { + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + 
case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); + } + } else if (*src) { + return decode_utf8(src); + } + throw std::runtime_error("unexpected end of input"); +} + +static void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + fprintf(file, "%c", static_cast(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "", c); + } +} + +static bool is_char_element(llama_grammar_element elem) { + switch (elem.type) { + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; + default: return false; + } +} + +static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) { + for (auto elem : rule) { + switch (elem.type) { + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; + } + switch (elem.type) { + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; + } + } + fprintf(file, "\n"); +} + +static void print_rule( + FILE * file, + uint32_t rule_id, + const llama_grammar_rule & rule, + const std::map & symbol_id_names) { + if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + llama_grammar_element elem = rule[i]; + switch (elem.type) { + case LLAMA_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + 
print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: + break; + default: + fprintf(file, "] "); + } + } + } + fprintf(file, "\n"); +} + +// +// implementation +// + +uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) { + uint32_t next_id = static_cast(symbol_ids.size()); + auto result = symbol_ids.emplace(std::string(src, len), next_id); + return result.first->second; +} + +uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) { + uint32_t next_id = static_cast(symbol_ids.size()); + symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; +} + +void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) { + if (rules.size() <= rule_id) { + rules.resize(rule_id + 1); + } + rules[rule_id] = rule; +} + +const char * llama_grammar_parser::parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested) { + llama_grammar_rule rule; + const char * pos = parse_sequence(src, rule_name, rule, is_nested); + while (*pos == '|') { + rule.push_back({LLAMA_GRETYPE_ALT, 0}); + pos = parse_space(pos + 1, true); + pos = parse_sequence(pos, rule_name, rule, is_nested); + } + rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(rule_id, rule); + return pos; +} + +const char * llama_grammar_parser::parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested) { + size_t last_sym_start = rule.size(); + const char * pos = src; + + auto handle_repetitions = [&](int min_times, int max_times) { + + if (last_sym_start == rule.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | + + llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + if (min_times == 0) { + rule.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); + } + } + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + llama_grammar_rule rec_rule(prev_rule); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(prev_rule.size()); + uint32_t rec_rule_id = generate_symbol_id( rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? 
rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule( rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = rule.size(); + while (*pos != '"') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '[') { // char range(s) + pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } + last_sym_start = rule.size(); + while (*pos != ']') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum llama_gretype type = last_sym_start < rule.size() + ? LLAMA_GRETYPE_CHAR_ALT + : start_type; + + rule.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + if (!pos[1]) { + throw std::runtime_error("unexpected end of input"); + } + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(rule_name); + pos = parse_alternates(pos, rule_name, sub_rule_id, true); + last_sym_start = rule.size(); + // output reference to synthesized rule + rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '.') { // any char + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); + + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); + } + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + + int max_times = -1; + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else 
{ + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); + } else { + break; + } + } + return pos; + } + +const char * llama_grammar_parser::parse_rule(const char * src) { + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); + } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); + } + +bool llama_grammar_parser::parse(const char * src) { + try { + const char * pos = parse_space(src, true); + while (*pos) { + pos = parse_rule(pos); + } + // Validate the state to ensure that all rules are defined + for (const auto & rule : rules) { + if (rule.empty()) { + throw std::runtime_error("Undefined rule"); + } + for (const auto & elem : rule) { + if (elem.type == LLAMA_GRETYPE_RULE_REF) { + // Ensure that the rule at that location exists + if (elem.value >= rules.size() || rules[elem.value].empty()) { + // Get the name of the rule that is missing + for (const auto & kv : symbol_ids) { + if (kv.second == elem.value) { + throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); + } + } + } + } + } + } + } catch (const std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + rules.clear(); + return false; + } + + return true; +} + +void llama_grammar_parser::print(FILE * file) { + try { + std::map symbol_id_names; + for (const auto & kv : symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, rules[i]); + print_rule(file, uint32_t(i), rules[i], symbol_id_names); + // fprintf(file, "\n"); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); + } +} + +llama_grammar_stack llama_grammar_parser::c_rules() const { + llama_grammar_stack ret; + ret.reserve(rules.size()); + for (const auto & rule : rules) { + ret.push_back(rule.data()); + } + return ret; } // returns true iff pos points to the end of one of the definitions of a rule @@ -89,7 +606,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) static std::pair llama_grammar_match_char( const llama_grammar_element * pos, const uint32_t chr) { - bool found = false; bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; @@ -225,16 +741,92 @@ static void llama_grammar_advance_stack( } } -// takes a set of possible pushdown stacks on a grammar, which are required to -// be positioned at a character range (see `llama_grammar_advance_stack`), and -// produces the N possible stacks if the given char is accepted at those -// positions -void llama_grammar_accept( +static llama_grammar_candidates llama_grammar_reject_candidates( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const llama_grammar_candidates & candidates) { + GGML_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + 
return {}; + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + + return rejects; +} + +static bool llama_grammar_detect_left_recursion( + const llama_grammar_rules & rules, + size_t rule_index, + std::vector * rules_visited, + std::vector * rules_in_progress, + std::vector * rules_may_be_empty) { + if ((*rules_in_progress)[rule_index]) { + return true; + } + + (*rules_in_progress)[rule_index] = true; + + const llama_grammar_rule & rule = rules[rule_index]; + + // First check if the rule might produce the empty string. This could be done combined with the second + // step but it's more readable as two steps. + bool at_rule_start = true; + for (size_t i = 0; i < rule.size(); i++) { + if (llama_grammar_is_end_of_sequence(&rule[i])) { + if (at_rule_start) { + (*rules_may_be_empty)[rule_index] = true; + break; + } + at_rule_start = true; + } else { + at_rule_start = false; + } + } + + // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may + // be empty) + bool recurse_into_nonterminal = true; + for (size_t i = 0; i < rule.size(); i++) { + if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { + if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { + return true; + } + if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { + recurse_into_nonterminal = false; + } + } else if (llama_grammar_is_end_of_sequence(&rule[i])) { + recurse_into_nonterminal = true; + } else { + recurse_into_nonterminal = false; + } + } + + (*rules_in_progress)[rule_index] = false; + (*rules_visited)[rule_index] = true; + + return false; +} + +const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { + return grammar->rules; +} + +llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { + return grammar->stacks; +} + +llama_grammar_stacks llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks) { - new_stacks.clear(); + const uint32_t chr) { + llama_grammar_stacks result; + result.reserve(stacks.size()); for (const auto & stack : stacks) { if (stack.empty()) { @@ -250,27 +842,11 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + llama_grammar_advance_stack(rules, new_stack, result); } } -} -static llama_grammar_candidates llama_grammar_reject_candidates( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const llama_grammar_candidates & candidates) { - GGML_ASSERT(!stacks.empty()); // REVIEW - - if (candidates.empty()) { - return {}; - } - - auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); - - for (size_t i = 1, size = stacks.size(); i < size; ++i) { - rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); - } - return rejects; + return result; } llama_grammar_candidates llama_grammar_reject_candidates_for_stack( @@ -328,66 +904,13 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( return rejects; } -static bool llama_grammar_detect_left_recursion( - const llama_grammar_rules & rules, - size_t rule_index, - 
std::vector * rules_visited, - std::vector * rules_in_progress, - std::vector * rules_may_be_empty) { - if ((*rules_in_progress)[rule_index]) { - return true; - } - - (*rules_in_progress)[rule_index] = true; - - const llama_grammar_rule & rule = rules[rule_index]; - - // First check if the rule might produce the empty string. This could be done combined with the second - // step but it's more readable as two steps. - bool at_rule_start = true; - for (size_t i = 0; i < rule.size(); i++) { - if (llama_grammar_is_end_of_sequence(&rule[i])) { - if (at_rule_start) { - (*rules_may_be_empty)[rule_index] = true; - break; - } - at_rule_start = true; - } else { - at_rule_start = false; - } - } - - // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may - // be empty) - bool recurse_into_nonterminal = true; - for (size_t i = 0; i < rule.size(); i++) { - if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { - if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { - return true; - } - if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { - recurse_into_nonterminal = false; - } - } else if (llama_grammar_is_end_of_sequence(&rule[i])) { - recurse_into_nonterminal = true; - } else { - recurse_into_nonterminal = false; - } - } - - (*rules_in_progress)[rule_index] = false; - (*rules_visited)[rule_index] = true; - return false; -} - -// -// grammar - external -// +//////////////////// struct llama_grammar * llama_grammar_init_impl( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { + const struct llama_vocab * vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { const llama_grammar_element * pos; // copy rule definitions into vectors @@ -438,22 +961,100 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; +} + +struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { + llama_grammar_parser parser; + + // if there is a grammar, parse it + if (!parser.parse(grammar_str)) { + return nullptr; + } + + // will be empty (default) if there are parse errors + if (parser.rules.empty()) { + fprintf(stderr, "%s: failed to parse grammar\n", __func__); + return nullptr; + } + + // Ensure that there is a "root" node. 
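// Illustrative aside (standalone sketch, separate from the patch): grammars fed to the
// string-based initializer here must define the start symbol named by grammar_root
// (normally "root"), as the note above says. Assuming the internal header
// src/llama-grammar.h introduced by this patch is on the include path, a minimal GBNF
// grammar can be pushed through llama_grammar_parser and dumped like this:
#include "llama-grammar.h"

#include <cstdio>

int main() {
    // GBNF text; the start symbol must be named "root" (or whatever grammar_root names)
    const char * gbnf = "root ::= \"yes\" | \"no\"\n";

    llama_grammar_parser parser;
    if (!parser.parse(gbnf)) {            // returns false and clears rules on parse errors
        std::fprintf(stderr, "parse failed\n");
        return 1;
    }
    parser.print(stdout);                 // e.g. prints: root ::= [y] [e] [s] | [n] [o]
    return 0;
}
// end of aside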
+ if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { + fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + return nullptr; + } + + std::vector grammar_rules(parser.c_rules()); + + const size_t n_rules = grammar_rules.size(); + const size_t start_rule_index = parser.symbol_ids.at(grammar_root); + + const llama_grammar_element * pos; + + // copy rule definitions into vectors + llama_grammar_rules vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } + + // Check for left recursion + std::vector rules_visited(n_rules); + std::vector rules_in_progress(n_rules); + std::vector rules_may_be_empty(n_rules); + for (size_t i = 0; i < n_rules; i++) { + if (rules_visited[i]) { + continue; + } + if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + return nullptr; + } + } + + // loop over alternates of start rule to build initial stacks + llama_grammar_stacks stacks; + pos = vec_rules[start_rule_index].data(); + do { + llama_grammar_stack stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + // Important: vec_rules has to be moved here, not copied, because stacks contains + // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar + // then the pointers would be invalidated when the local vec_rules goes out of scope. 
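// Illustrative aside (hedged sketch, separate from the patch): the left-recursion pass
// above rejects rules that reference themselves in leftmost position. This assumes the
// internal API links as in this patch and that a null vocab is tolerated for such a
// test, as the llama_grammar header note allows.
#include "llama-grammar.h"

#include <cassert>

int main() {
    // "expr" refers to itself in leftmost position: the check above logs an error
    // and llama_grammar_init_impl returns nullptr
    const char * left_recursive =
        "root ::= expr\n"
        "expr ::= expr \"+\" num | num\n"
        "num  ::= [0-9]+\n";

    // equivalent grammar rewritten without left recursion: accepted
    const char * iterative =
        "root ::= expr\n"
        "expr ::= num (\"+\" num)*\n"
        "num  ::= [0-9]+\n";

    assert(llama_grammar_init_impl(nullptr, left_recursive, "root") == nullptr);

    llama_grammar * g = llama_grammar_init_impl(nullptr, iterative, "root");
    assert(g != nullptr);
    llama_grammar_free_impl(g);
    return 0;
}
// end of aside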
+ return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { delete grammar; } -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) { - llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 }; +struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar) { + llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { - for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { - for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { - if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { + for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { + if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { result->stacks[is][ie] = &result->rules[ir0][ir1]; } } @@ -464,14 +1065,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram return result; } -void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) { - GGML_ASSERT(grammar); - GGML_ASSERT(vocab); - - int64_t t_start_sample_us = ggml_time_us(); +void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * candidates) { + GGML_ASSERT(grammar.vocab != nullptr); bool allow_eog = false; - for (const auto & stack : grammar->stacks) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { allow_eog = true; break; @@ -486,33 +1084,31 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string & piece = vocab->cache_token_to_piece.at(id); + const std::string & piece = grammar.vocab->cache_token_to_piece.at(id); - if (llama_token_is_eog_impl(*vocab, id)) { + if (llama_token_is_eog_impl(*grammar.vocab, id)) { if (!allow_eog) { candidates->data[i].logit = -INFINITY; } } else if (piece.empty() || piece[0] == 0) { candidates->data[i].logit = -INFINITY; } else { - candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); + candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } } - const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { candidates->data[reject.index].logit = -INFINITY; } - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { - const int64_t t_start_sample_us = ggml_time_us(); +void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { + GGML_ASSERT(grammar.vocab != nullptr); - if (llama_token_is_eog_impl(*vocab, token)) { - for (const auto & stack : grammar->stacks) { + if 
(llama_token_is_eog_impl(*grammar.vocab, token)) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; } @@ -520,20 +1116,17 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc GGML_ABORT("fatal error"); } - const std::string & piece = vocab->cache_token_to_piece.at(token); + const std::string & piece = grammar.vocab->cache_token_to_piece.at(token); // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece, grammar->partial_utf8); + const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; - llama_grammar_stacks tmp_new_stacks; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks); - grammar->stacks = tmp_new_stacks; + llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it); + grammar.stacks = std::move(new_stacks); } - grammar->partial_utf8 = decoded.second; - GGML_ASSERT(!grammar->stacks.empty()); - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + grammar.partial_utf8 = decoded.second; + GGML_ASSERT(!grammar.stacks.empty()); } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 695ea0632..9b13354f6 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -2,11 +2,114 @@ #include "llama-impl.h" +#include + struct llama_vocab; -struct llama_sampling; + +// grammar element type +enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) 
+ LLAMA_GRETYPE_CHAR_ANY = 7, +}; + +typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID +} llama_grammar_element; + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; +}; + +using llama_grammar_rule = std::vector< llama_grammar_element>; +using llama_grammar_stack = std::vector; + +using llama_grammar_rules = std::vector; +using llama_grammar_stacks = std::vector; +using llama_grammar_candidates = std::vector; + +const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); + llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +llama_grammar_stacks llama_grammar_accept( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + uint32_t chr); + +std::vector llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates); + +struct llama_grammar_parser { + std::map symbol_ids; + + llama_grammar_rules rules; + + llama_grammar_stack c_rules() const; + + uint32_t get_symbol_id(const char * src, size_t len); + uint32_t generate_symbol_id(const std::string & base_name); + + void add_rule(uint32_t rule_id, const llama_grammar_rule & rule); + + const char * parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested); + + const char * parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested); + + const char * parse_rule(const char * src); + + bool parse(const char * src); + void print(FILE * file); +}; struct llama_grammar { - const llama_grammar_rules rules; + // note: allow null vocab for testing (not great) + const llama_vocab * vocab; + + const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; // buffer for partially generated UTF-8 sequence from accepted tokens @@ -17,23 +120,24 @@ struct llama_grammar { // internal API // +// note: needed for tests (not great) struct llama_grammar * llama_grammar_init_impl( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); + const struct llama_vocab * vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + +struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); void llama_grammar_free_impl(struct llama_grammar * grammar); -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar); +struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar); -void llama_grammar_sample_impl( - const struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, +// TODO: move the API below as member functions of llama_grammar +void llama_grammar_apply_impl( + const struct llama_grammar & grammar, llama_token_data_array * candidates); -void llama_grammar_accept_token_impl( - struct llama_grammar * 
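// Illustrative aside (standalone sketch with local stand-ins, not the library types):
// a character class such as [a-zA-Z_] is flattened into one alternative of
// CHAR / CHAR_RNG_UPPER / CHAR_ALT elements as described by the enum comments above.
// The matcher below follows that chaining; enum names are shortened copies.
#include <cstdint>
#include <cstdio>

enum gretype { END = 0, ALT, RULE_REF, CHAR, CHAR_NOT, CHAR_RNG_UPPER, CHAR_ALT, CHAR_ANY };
struct element { gretype type; uint32_t value; };

// accept a code point against a positive char alternative
static bool match_positive_class(const element * pos, uint32_t chr) {
    bool found = false;
    do {
        if (pos[1].type == CHAR_RNG_UPPER) {
            found = found || (pos->value <= chr && chr <= pos[1].value); // inclusive range
            pos += 2;
        } else {
            found = found || pos->value == chr;                          // single character
            pos += 1;
        }
    } while (pos->type == CHAR_ALT || pos->type == CHAR_RNG_UPPER);
    return found;
}

int main() {
    // [a-zA-Z_] the way the parser in this patch would flatten it, terminated by END
    const element cls[] = {
        {CHAR,     'a'}, {CHAR_RNG_UPPER, 'z'},
        {CHAR_ALT, 'A'}, {CHAR_RNG_UPPER, 'Z'},
        {CHAR_ALT, '_'},
        {END, 0},
    };
    std::printf("%d %d %d\n",
                match_positive_class(cls, 'q'),
                match_positive_class(cls, '_'),
                match_positive_class(cls, '7'));  // prints: 1 1 0
    return 0;
}
// end of aside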
grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, +void llama_grammar_accept_impl( + struct llama_grammar & grammar, llama_token token); diff --git a/src/llama-impl.h b/src/llama-impl.h index 952774096..b67f511c0 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -1,8 +1,11 @@ #pragma once -#define LLAMA_API_INTERNAL #include "llama.h" +#include +#include +#include + #ifdef __GNUC__ #ifdef __MINGW32__ #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) @@ -45,3 +48,114 @@ static void replace_all(std::string & s, const std::string & search, const std:: builder.append(s, last_pos, std::string::npos); s = std::move(builder); } + +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +); + +// the ring buffer works similarly to std::deque, but with a fixed capacity +template +struct ring_buffer { + ring_buffer() {} + ring_buffer(size_t cap) : capacity(cap), data(cap) {} + + T & front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + const T & front() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + T & back() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + const T & back() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + void push_back(const T & value) { + if (sz == capacity) { + // advance the start when buffer is full + first = (first + 1) % capacity; + } else { + sz++; + } + data[pos] = value; + pos = (pos + 1) % capacity; + } + + T pop_front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + T value = data[first]; + first = (first + 1) % capacity; + sz--; + return value; + } + + //T & operator[](size_t i) { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + //const T & at(size_t i) const { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + const T & rat(size_t i) const { + if (i >= sz) { + throw std::runtime_error("ring buffer: index out of bounds"); + } + return data[(first + sz - i - 1) % capacity]; + } + + std::vector to_vector() const { + std::vector result; + result.reserve(sz); + for (size_t i = 0; i < sz; i++) { + result.push_back(data[(first + i) % capacity]); + } + return result; + } + + void clear() { + // here only reset the status of the buffer + sz = 0; + first = 0; + pos = 0; + } + + bool empty() const { + return sz == 0; + } + + size_t size() const { + return sz; + } + + size_t capacity = 0; + size_t sz = 0; + size_t first = 0; + size_t pos = 0; + std::vector data; +}; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 8f4841d9d..8abfc3fc6 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1,5 +1,8 @@ #include "llama-sampling.h" +#include "llama-vocab.h" +#include "llama-grammar.h" + #include #include #include @@ -21,18 +24,104 @@ static void llama_log_softmax(float * array, size_t size) { } } -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) { +llama_sampling::llama_sampling(const struct llama_vocab & vocab) : vocab(vocab) { +} + +llama_sampling::~llama_sampling() { + if (grammar) { + llama_grammar_free_impl(grammar); + } +} + +struct llama_sampling * 
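// Illustrative aside (usage sketch, assuming llama-impl.h can be included on its own):
// the ring_buffer added above backs the prev token history of llama_sampling below,
// and llama_sampling_prev_impl reads it with rat(). Fixed-capacity behaviour:
#include "llama-impl.h"

#include <cstdio>

int main() {
    ring_buffer<int> rb(3);                 // fixed capacity of 3
    for (int t = 1; t <= 5; ++t) {
        rb.push_back(t);                    // 4 and 5 overwrite the two oldest entries
    }
    for (int v : rb.to_vector()) {          // logical contents, oldest first: 3 4 5
        std::printf("%d ", v);
    }
    std::printf("\nmost recent: %d (size=%zu)\n", rb.rat(0), rb.size()); // rat(0) is the newest: 5
    return 0;
}
// end of aside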
llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) { + auto * result = new llama_sampling(vocab); + + result->params = params; + + result->prev = ring_buffer(params.n_prev); + + for (int i = 0; i < params.n_samplers; ++i) { + result->samplers.push_back(params.samplers[i]); + } + + llama_sampling_set_rng_seed_impl(*result, params.seed); + + return result; +} + +void llama_sampling_free_impl(struct llama_sampling * sampling) { + delete sampling; +} + +struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) { + auto * result = new llama_sampling(smpl.vocab); + + result->params = smpl.params; + + result->grammar_str = smpl.grammar_str; + result->grammar_root = smpl.grammar_root; + + result->logit_bias = smpl.logit_bias; + + if (smpl.grammar) { + result->grammar = llama_grammar_cp_impl(*smpl.grammar); + } + + result->rng = smpl.rng; + result->prev = smpl.prev; + + return result; +} + +void llama_sampling_reset_impl(struct llama_sampling & smpl) { + if (smpl.grammar) { + llama_grammar_free_impl(smpl.grammar); + smpl.grammar = nullptr; + } + + if (!smpl.grammar_str.empty()) { + smpl.grammar = llama_grammar_init_impl(&smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data()); + } + + smpl.prev.clear(); +} + +void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { seed = time(NULL); } - smpl->rng.seed(seed); + smpl.rng.seed(seed); } -void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - GGML_ASSERT(candidates->size > 0); +void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) { + if (smpl.grammar) { + llama_grammar_free_impl(smpl.grammar); + smpl.grammar = nullptr; + } - const int64_t t_start_sample_us = ggml_time_us(); + if (grammar_str != nullptr && grammar_str[0] != '\0') { + smpl.grammar_str = grammar_str; + smpl.grammar_root = grammar_root; + + smpl.grammar = llama_grammar_init_impl(&smpl.vocab, grammar_str, grammar_root); + } else { + smpl.grammar_str.clear(); + smpl.grammar_root.clear(); + } +} + +void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { + smpl.logit_bias.clear(); + smpl.logit_bias.reserve(n_logit_bias); + + for (int32_t i = 0; i < n_logit_bias; ++i) { + smpl.logit_bias.push_back(logit_bias[i]); + } +} + +void llama_sampling_softmax_impl(llama_token_data_array * candidates) { + GGML_ASSERT(candidates->size > 0); // Sort the logits in descending order if (!candidates->sorted) { @@ -44,28 +133,24 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar float max_l = candidates->data[0].logit; float cum_sum = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { float p = expf(candidates->data[i].logit - max_l); candidates->data[i].p = p; cum_sum += p; } + for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].p /= cum_sum; } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { +void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)candidates->size) { // return; // } - const int64_t 
t_start_sample_us = ggml_time_us(); - if (k <= 0) { k = candidates->size; } @@ -101,10 +186,12 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra int ib = nbuckets - 1; for ( ; ib >= 0; --ib) { nhave += histo[ib]; - if (nhave >= k) break; + if (nhave >= k) { + break; + } } std::vector tmp_tokens(nhave); - auto ptr = tmp_tokens.data(); + auto * ptr = tmp_tokens.data(); std::vector bucket_ptrs; bucket_ptrs.reserve(nbuckets - ib); for (int j = nbuckets - 1; j >= ib; --j) { @@ -133,20 +220,14 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra candidates->sorted = true; } candidates->size = k; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_sample_softmax_impl(smpl, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); + llama_sampling_softmax_impl(candidates); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -165,19 +246,13 @@ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_arra // Resize the output vector to keep only the top-p tokens candidates->size = last_idx; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { if (p <= 0.0f || !candidates->size) { return; } - const int64_t t_start_sample_us = ggml_time_us(); - bool min_p_applied = false; // if the candidates aren't sorted, try the unsorted implementation first @@ -226,19 +301,14 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra // Resize the output vector to keep only the matching tokens candidates->size = i; } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { +void llama_sampling_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; } - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - const int64_t t_start_sample_us = ggml_time_us(); + llama_sampling_softmax_impl(candidates); // Compute the first and second derivatives std::vector first_derivatives(candidates->size - 1); @@ -285,13 +355,9 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_ // Resize the output vector to keep only the tokens above the tail location candidates->size = last_idx; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -299,9 +365,7 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar } // Compute the softmax of logits and calculate entropy - 
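// Illustrative aside (self-contained sketch, local struct instead of llama_token_data):
// llama_sampling_top_p_impl above keeps the smallest prefix of the probability-sorted
// candidates whose cumulative probability reaches p, never fewer than min_keep.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct cand { int id; float logit; float p; };

// softmax over logits, then truncate at cumulative probability >= top_p
static void top_p_filter(std::vector<cand> & cands, float top_p, size_t min_keep) {
    std::sort(cands.begin(), cands.end(),
              [](const cand & a, const cand & b) { return a.logit > b.logit; });

    const float max_l = cands[0].logit;
    float sum = 0.0f;
    for (auto & c : cands) { c.p = std::exp(c.logit - max_l); sum += c.p; }
    for (auto & c : cands) { c.p /= sum; }

    float  cum  = 0.0f;
    size_t last = cands.size();
    for (size_t i = 0; i < cands.size(); ++i) {
        cum += cands[i].p;
        if (cum >= top_p && i + 1 >= min_keep) {  // same stopping rule as above
            last = i + 1;
            break;
        }
    }
    cands.resize(last);                           // keep only the nucleus
}

int main() {
    std::vector<cand> cands = { {0, 2.0f}, {1, 1.0f}, {2, 0.0f}, {3, -1.0f} };
    top_p_filter(cands, 0.90f, 1);
    for (const auto & c : cands) {
        std::printf("id=%d p=%.3f\n", c.id, c.p); // 3 of the 4 candidates survive
    }
    return 0;
}
// end of aside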
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); + llama_sampling_softmax_impl(candidates); float entropy = 0.0f; for (size_t i = 0; i < candidates->size; ++i) { @@ -349,15 +413,9 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); candidates->size = new_candidates.size(); candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { - const int64_t t_start_sample_us = ggml_time_us(); - +void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates if(candidates->size <= 1) { return; @@ -366,7 +424,7 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar // Calculate maximum possible entropy float max_entropy = -logf(1.0f / candidates->size); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + llama_sampling_softmax_impl(candidates); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -398,13 +456,15 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar } // Re-compute softmax probabilities after scaling logits with dynamic temperature - double max_l_double = candidates->data[0].logit; + const double max_l_double = candidates->data[0].logit; + double cum_sum_double = 0.0; for (size_t i = 0; i < candidates->size; ++i) { double p = exp(candidates->data[i].logit - max_l_double); candidates->data[i].p = p; // Store the scaled probability cum_sum_double += p; } + for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities } @@ -416,44 +476,24 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f); } #endif - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) { - const int64_t t_start_sample_us = ggml_time_us(); - +void llama_sampling_temp_impl(llama_token_data_array * candidates, float temp) { for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].logit /= temp; } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, +void llama_sampling_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) { + llama_grammar_apply_impl(grammar, candidates); +} + +void llama_sampling_penalties_impl( llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { - if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { - return; - } - - const int64_t t_start_sample_us = ggml_time_us(); - - // Create a frequency map to count occurrences of each token in last_tokens - std::unordered_map token_count; - for (size_t i = 0; i < penalty_last_n; ++i) { - token_count[last_tokens[i]]++; - } - + const 
llama_token_cnt & token_count, + float penalty_repeat, + float penalty_freq, + float penalty_present) { // Apply frequency and presence penalties to the candidates for (size_t i = 0; i < candidates->size; ++i) { const auto token_iter = token_count.find(candidates->data[i].id); @@ -475,43 +515,10 @@ void llama_sample_repetition_penalties_impl( } candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, - float * logits, - float * logits_guidance, - float scale) { - GGML_ASSERT(smpl); - - const auto t_start_sample_us = ggml_time_us(); - const auto n_vocab = smpl->n_vocab; - - llama_log_softmax(logits, n_vocab); - llama_log_softmax(logits_guidance, n_vocab); - - for (int i = 0; i < n_vocab; ++i) { - auto & l = logits[i]; - const auto & g = logits_guidance[i]; - - l = scale * (l - g) + g; - } - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; -} - -llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - GGML_ASSERT(smpl); - - const int32_t n_vocab = float(smpl->n_vocab); - - int64_t t_start_sample_us = ggml_time_us(); - - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); +llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { + llama_sampling_softmax_impl(candidates); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -527,13 +534,11 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama // Compute k from the estimated s_hat and target surprise value float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); + float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1); - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = ggml_time_us(); + llama_sampling_top_k_impl(candidates, int(k), 1); + llama_token X = llama_sampling_sample_dist_impl(candidates, rng); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -543,93 +548,88 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama float e = observed_surprise - tau; // Update mu using the learning rate and error - *mu = *mu - eta * e; + mu = mu - eta * e; - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; return X; } -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { - int64_t t_start_sample_us; - t_start_sample_us = ggml_time_us(); - - llama_sample_softmax_impl(smpl, candidates); +llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { + llama_sampling_softmax_impl(candidates); // Truncate the words with surprise values greater than mu candidates->size = std::distance(candidates->data, 
std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > *mu; + return -log2f(candidate.p) > mu; })); if (candidates->size == 0) { candidates->size = 1; } - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } - // Normalize the probabilities of the remaining words - llama_sample_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(candidates); // Sample the next word X from the remaining words - llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = ggml_time_us(); + llama_token X = llama_sampling_sample_dist_impl(candidates, rng); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { return candidate.id == X; })); + float observed_surprise = -log2f(candidates->data[X_idx].p); float e = observed_surprise - tau; // Update mu using the learning rate and error - *mu = *mu - eta * e; + mu = mu - eta * e; - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } return X; } -llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - const int64_t t_start_sample_us = ggml_time_us(); - +llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidates) { // Find max element auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; }); llama_token result = max_iter->id; - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - } + return result; } -llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) { - GGML_ASSERT(smpl); - - const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); +llama_token llama_sampling_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { + llama_sampling_softmax_impl(candidates); std::vector probs; probs.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { probs.push_back(candidates->data[i].p); } std::discrete_distribution<> dist(probs.begin(), probs.end()); - int idx = dist(rng); + const int idx = dist(rng); llama_token result = candidates->data[idx].id; - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - return result; } -llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng); +void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar) { + smpl.prev.push_back(token); + + if (apply_grammar && smpl.grammar) { + llama_grammar_accept_impl(*smpl.grammar, token); + } +} + +llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith) { + if (ith < 0 || ith >= (int) smpl.prev.size()) { + return LLAMA_TOKEN_NULL; + } + + return smpl.prev.rat(ith); +} + +int llama_sampling_n_prev_impl(const struct llama_sampling & smpl) { + return smpl.prev.size(); } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index f7f8e3ef7..c51542259 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -1,56 
+1,107 @@ #pragma once -#include "llama-impl.h" +#include "llama-grammar.h" + +#include +#include + +struct llama_vocab; +struct llama_grammar; + +using llama_token_cnt = std::unordered_map; struct llama_sampling { - llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} + llama_sampling(const struct llama_vocab & vocab); + ~llama_sampling(); + + llama_sampling_params params; + + std::string grammar_str; + std::string grammar_root; + + std::vector logit_bias; // logit biases to apply + + // state std::mt19937 rng; - int32_t n_vocab = 0; + const struct llama_vocab & vocab; - mutable int64_t t_sample_us = 0; - mutable int32_t n_sample = 0; + std::vector samplers; - void reset_timings() const { - t_sample_us = 0; - n_sample = 0; - } + ring_buffer prev; + + struct llama_grammar * grammar = nullptr; + + // mirostat sampler state + float mirostat_mu; + + mutable int64_t t_sample_us = 0; + mutable int64_t t_grammar_us = 0; + mutable int64_t t_accept_us = 0; + + mutable int32_t n_sample = 0; + mutable int32_t n_grammar = 0; + mutable int32_t n_accept = 0; + + std::vector cur; + + llama_token_data_array cur_p; }; // // internal API // -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed); +struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params); -void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep); -void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); -void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp); +void llama_sampling_free_impl(struct llama_sampling * sampling); -void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, +struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl); + +void llama_sampling_reset_impl(struct llama_sampling & smpl); + +// TODO: move the API below as member functions of llama_sampling +void llama_sampling_set_rng_seed_impl (struct llama_sampling & smpl, uint32_t seed); +void llama_sampling_set_grammar_impl (struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root); +void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); + +void llama_sampling_softmax_impl (struct llama_token_data_array * candidates); +void llama_sampling_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep); +void llama_sampling_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_min_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_tail_free_impl(struct llama_token_data_array * candidates, float z, size_t min_keep); +void 
llama_sampling_typical_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_entropy_impl (struct llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); +void llama_sampling_temp_impl (struct llama_token_data_array * candidates, float temp); +void llama_sampling_grammar_impl (struct llama_token_data_array * candidates, const struct llama_grammar & grammar); + +void llama_sampling_penalties_impl( llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, + const llama_token_cnt & token_count, float penalty_repeat, float penalty_freq, float penalty_present); -void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, - float * logits, - float * logits_guidance, - float scale); +/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. +llama_token llama_sampling_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); -llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); -llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); -llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); +/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. 
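// Illustrative aside (plain arithmetic with example values only): the mirostat feedback
// described in this comment block compares the observed surprise of the sampled token
// against tau and nudges mu by the learning rate eta.
#include <cmath>
#include <cstdio>

int main() {
    const float tau = 5.0f;           // target surprise (cross-entropy)
    const float eta = 0.1f;           // learning rate
    float       mu  = 2.0f * tau;     // initialized to 2*tau, as documented here

    const float p_sampled = 0.02f;    // probability of the token just sampled (example value)
    const float observed_surprise = -std::log2(p_sampled);
    const float e = observed_surprise - tau;

    mu -= eta * e;                    // same update as "mu = mu - eta * e" in both samplers

    std::printf("surprise = %.3f, updated mu = %.3f\n", observed_surprise, mu);
    return 0;
}
// end of aside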
A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. +llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu); +llama_token llama_sampling_sample_greedy_impl(struct llama_token_data_array * candidates); +llama_token llama_sampling_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng); + +void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar); + +llama_token llama_sampling_prev_impl (const struct llama_sampling & smpl, int ith); +int llama_sampling_n_prev_impl(const struct llama_sampling & smpl); diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 6e8f30be4..dc4b5f12f 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -18,6 +18,8 @@ struct llama_vocab { tattr attr; }; + uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -62,8 +64,6 @@ struct llama_vocab { int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; }; -const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx); - // // internal API // @@ -76,6 +76,7 @@ std::vector llama_tokenize_internal( bool add_special, bool parse_special = false); +// TODO: move the API below as member functions of llama_vocab llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch); const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token); diff --git a/src/llama.cpp b/src/llama.cpp index 883559716..6d24109d8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1,6 +1,5 @@ #include "llama-impl.h" #include "llama-vocab.h" -#include "llama-grammar.h" #include "llama-sampling.h" #include "unicode.h" @@ -148,6 +147,19 @@ static void zeros(std::ofstream & file, size_t n) { } } +struct time_meas { + time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {} + + ~time_meas() { + t_acc += ggml_time_us() - t_start_us; + } + + const int64_t t_start_us; + + int64_t & t_acc; +}; + + LLAMA_ATTRIBUTE_FORMAT(1, 2) static std::string format(const char * fmt, ...) 
{ va_list ap; @@ -3179,7 +3191,6 @@ struct llama_sbatch { struct llama_context { llama_context(const llama_model & model) : model(model) - , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {} @@ -3196,7 +3207,6 @@ struct llama_context { const struct llama_model & model; struct llama_cparams cparams; - struct llama_sampling sampling; struct llama_sbatch sbatch; struct llama_kv_cache kv_self; struct llama_control_vector cvec; @@ -3217,16 +3227,16 @@ struct llama_context { bool has_evaluated_once = false; - int64_t t_start_us; - int64_t t_load_us; - int64_t t_p_eval_us = 0; - int64_t t_eval_us = 0; + mutable int64_t t_start_us; + mutable int64_t t_load_us; + mutable int64_t t_p_eval_us = 0; + mutable int64_t t_eval_us = 0; - int64_t t_compute_start_us = 0; - int64_t n_queued_tokens = 0; + mutable int64_t t_compute_start_us = 0; + mutable int64_t n_queued_tokens = 0; - int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - int32_t n_eval = 0; // number of eval calls + mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + mutable int32_t n_eval = 0; // number of eval calls // host buffer for the model output (logits and embeddings) ggml_backend_buffer_t buf_output = nullptr; @@ -6239,6 +6249,7 @@ static void llm_load_vocab( const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); + vocab.n_vocab = n_vocab; vocab.id_to_token.resize(n_vocab); for (uint32_t i = 0; i < n_vocab; i++) { @@ -17863,7 +17874,6 @@ struct llama_model_params llama_model_default_params() { struct llama_context_params llama_context_default_params() { struct llama_context_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, @@ -17896,6 +17906,36 @@ struct llama_context_params llama_context_default_params() { return result; } +struct llama_sampling_params llama_sampling_default_params() { + struct llama_sampling_params result = { + /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_prev =*/ 64, + /*.n_probs =*/ 0, + /*.min_keep =*/ 0, + /*.top_k =*/ 40, + /*.top_p =*/ 0.95f, + /*.min_p =*/ 0.05f, + /*.tfs_z =*/ 1.00f, + /*.typ_p =*/ 1.00f, + /*.temp =*/ 0.80f, + /*.dynatemp_range =*/ 0.00f, + /*.dynatemp_exponent =*/ 1.00f, + /*.penalty_last_n =*/ 64, + /*.penalty_repeat =*/ 1.00f, + /*.penalty_freq =*/ 0.00f, + /*.penalty_present =*/ 0.00f, + /*.mirostat =*/ 0, + /*.mirostat_tau =*/ 5.00f, + /*.mirostat_eta =*/ 0.10f, + /*.n_samplers =*/ 3, + /*.samplers =*/ { LLAMA_SAMPLER_TYPE_TEMPERATURE, LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TOP_P, }, + /*.penalize_nl =*/ false, + /*.ignore_eos =*/ false, + }; + + return result; +} + struct llama_model_quantize_params llama_model_quantize_default_params() { struct llama_model_quantize_params result = { /*.nthread =*/ 0, @@ -18149,10 +18189,6 @@ struct llama_context * llama_new_context_with_model( cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; } - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); @@ -18163,10 +18199,10 @@ struct llama_context * llama_new_context_with_model( ctx->abort_callback = params.abort_callback; ctx->abort_callback_data = params.abort_callback_data; - ctx->sampling.rng = std::mt19937(params.seed); - ctx->logits_all = params.logits_all; + 
ctx->logits_all = params.logits_all; + // build worst-case graph for encoder if a model contains encoder - ctx->is_encoding = llama_model_has_encoder(model); + ctx->is_encoding = llama_model_has_encoder(model); uint32_t kv_size = cparams.n_ctx; ggml_type type_k = params.type_k; @@ -18443,14 +18479,6 @@ void llama_free(struct llama_context * ctx) { delete ctx; } -const struct llama_model * llama_get_model(const struct llama_context * ctx) { - return &ctx->model; -} - -const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) { - return &ctx->model.vocab; -} - uint32_t llama_n_ctx(const struct llama_context * ctx) { return ctx->cparams.n_ctx; } @@ -18471,6 +18499,30 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { return model->vocab.type; } +int32_t llama_n_vocab(const struct llama_model * model) { + return model->hparams.n_vocab; +} + +int32_t llama_n_ctx_train(const struct llama_model * model) { + return model->hparams.n_ctx_train; +} + +int32_t llama_n_embd(const struct llama_model * model) { + return model->hparams.n_embd; +} + +int32_t llama_n_layer(const struct llama_model * model) { + return model->hparams.n_layer; +} + +const struct llama_model * llama_get_model(const struct llama_context * ctx) { + return &ctx->model; +} + +enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { + return ctx->cparams.pooling_type; +} + enum llama_rope_type llama_rope_type(const struct llama_model * model) { switch (model->arch) { // these models do not use RoPE @@ -18534,26 +18586,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { return LLAMA_ROPE_TYPE_NONE; } -enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { - return ctx->cparams.pooling_type; -} - -int32_t llama_n_vocab(const struct llama_model * model) { - return model->hparams.n_vocab; -} - -int32_t llama_n_ctx_train(const struct llama_model * model) { - return model->hparams.n_ctx_train; -} - -int32_t llama_n_embd(const struct llama_model * model) { - return model->hparams.n_embd; -} - -int32_t llama_n_layer(const struct llama_model * model) { - return model->hparams.n_layer; -} - float llama_rope_freq_scale_train(const struct llama_model * model) { return model->hparams.rope_freq_scale_train; } @@ -18970,14 +19002,14 @@ struct llama_data_write { // TODO: add more model-specific info which should prevent loading the session file if not identical } - void write_rng(const std::mt19937 & rng) { - std::ostringstream rng_ss; - rng_ss << rng; + //void write_rng(const std::mt19937 & rng) { + // std::ostringstream rng_ss; + // rng_ss << rng; - const std::string & rng_str = rng_ss.str(); + // const std::string & rng_str = rng_ss.str(); - write_string(rng_str); - } + // write_string(rng_str); + //} void write_output_ids(struct llama_context * ctx) { llama_output_reorder(ctx); @@ -19197,17 +19229,17 @@ struct llama_data_read { // TODO: add more info which needs to be identical but which is not verified otherwise } - void read_rng(std::mt19937 & rng) { - std::string rng_str; - read_string(rng_str); + //void read_rng(std::mt19937 & rng) { + // std::string rng_str; + // read_string(rng_str); - std::istringstream rng_ss(rng_str); - rng_ss >> rng; + // std::istringstream rng_ss(rng_str); + // rng_ss >> rng; - if (rng_ss.fail()) { - throw std::runtime_error("failed to load RNG state"); - } - } + // if (rng_ss.fail()) { + // throw std::runtime_error("failed to load RNG state"); + // } + //} void read_output_ids(struct llama_context * 
ctx) { std::vector output_pos; @@ -19637,8 +19669,6 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da data_ctx.write_model_info(ctx); - data_ctx.write_rng(ctx->sampling.rng); - // copy outputs data_ctx.write_output_ids(ctx); data_ctx.write_logits(ctx); @@ -19676,9 +19706,6 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da data_ctx.read_model_info(ctx); - // set rng - data_ctx.read_rng(ctx->sampling.rng); - // set outputs data_ctx.read_output_ids(ctx); data_ctx.read_logits(ctx); @@ -20081,8 +20108,9 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG GGML_ABORT("fatal error"); -#endif +#else return nullptr; +#endif } } @@ -20130,8 +20158,9 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG GGML_ABORT("fatal error"); -#endif +#else return nullptr; +#endif } } @@ -20564,125 +20593,350 @@ int32_t llama_chat_apply_template( return res; } -// -// grammar -// - -struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { - return llama_grammar_init_impl(rules, n_rules, start_rule_index); -} - -void llama_grammar_free(struct llama_grammar * grammar) { - llama_grammar_free_impl(grammar); -} - -struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { - return llama_grammar_copy_impl(grammar); -} - -void llama_grammar_sample( - const struct llama_grammar * grammar, - const struct llama_context * ctx, - llama_token_data_array * candidates) { - llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates); -} - -void llama_sample_grammar( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar) { - llama_grammar_sample(grammar, ctx, candidates); -} - -void llama_grammar_accept_token( - struct llama_grammar * grammar, - struct llama_context * ctx, - llama_token token) { - llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token); -} - // // sampling // -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - llama_set_rng_seed_impl(&ctx->sampling, seed); +struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) { + return llama_sampling_init_impl(model->vocab, params); } -void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { - llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates); +void llama_sampling_free(struct llama_sampling * smpl) { + if (smpl == nullptr) { + return; + } + + llama_sampling_free_impl(smpl); } -void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { - llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep); +struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) { + return llama_sampling_cp_impl(*smpl); } -void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_top_p_impl(ctx ? 
&ctx->sampling : nullptr, candidates, p, min_keep); +void llama_sampling_reset(struct llama_sampling * smpl) { + llama_sampling_reset_impl(*smpl); } -void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); +void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) { + llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root); } -void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { - llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep); +void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { + llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias); } -void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); +void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) { + const int n_vocab = smpl->vocab.n_vocab; + + smpl->cur.resize(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + for (const auto & lb : smpl->logit_bias) { + smpl->cur[lb.token].logit += lb.bias; + } + + if (smpl->params.ignore_eos) { + smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY; + } + + smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; + + // apply penalties + { + const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit; + + llama_sampling_penalties(smpl, &smpl->cur_p); + + if (!smpl->params.penalize_nl) { + for (size_t idx = 0; idx < smpl->cur_p.size; idx++) { + if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) { + smpl->cur_p.data[idx].logit = nl_logit; + break; + } + } + } + } } -void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { - llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val); +llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) { + return &smpl->cur_p; } -void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { - llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp); +void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_softmax_impl(candidates); } -void llama_sample_repetition_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { - llama_sample_repetition_penalties_impl(ctx ? 
&ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); +void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep); } -void llama_sample_apply_guidance( - struct llama_context * ctx, - float * logits, - float * logits_guidance, - float scale) { - llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale); +void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep); } -llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu); +void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep); } -llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { - return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu); +void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep); } -llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_greedy_impl(ctx ? 
&ctx->sampling : nullptr, candidates); +void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep); } -llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { - return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng); +void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + if (smpl->params.dynatemp_range > 0) { + const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range); + const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range); + + llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent); + } else { + llama_sampling_temp_impl(candidates, smpl->params.temp); + } } -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng); +void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_grammar_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + if (smpl->grammar) { + llama_sampling_grammar_impl(candidates, *smpl->grammar); + + smpl->n_grammar++; + } } +void llama_sampling_penalties( + struct llama_sampling * smpl, + llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + const size_t penalty_last_n = std::min(smpl->params.penalty_last_n, smpl->prev.size()); + + const float penalty_repeat = smpl->params.penalty_repeat; + const float penalty_freq = smpl->params.penalty_freq; + const float penalty_present = smpl->params.penalty_present; + + if ((penalty_last_n == 0) || + (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { + return; + } + + // Create a frequency map to count occurrences of each token in last_tokens + // TODO: move to sampling state and avoid reallocation + llama_token_cnt token_count; + for (size_t i = 0; i < penalty_last_n; ++i) { + token_count[smpl->prev.rat(i)]++; + } + + llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present); +} + +llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + const auto type = smpl->params.mirostat; + + llama_token res; + + if (type == 1) { + res = llama_sampling_sample_mirostat_impl(candidates, + smpl->rng, + smpl->params.mirostat_tau, + smpl->params.mirostat_eta, + 100, + smpl->vocab.n_vocab, + smpl->mirostat_mu); + } else if (type == 2) { + res = llama_sampling_sample_mirostat_v2_impl(candidates, + smpl->rng, + smpl->params.mirostat_tau, + smpl->params.mirostat_eta, + smpl->mirostat_mu); + } else { + GGML_ABORT("invalid mirostat type: %d", type); + } + + smpl->n_sample++; + + return res; +} + +llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == 
nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    auto res = llama_sampling_sample_greedy_impl(candidates);
+
+    smpl->n_sample++;
+
+    return res;
+}
+
+llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng);
+
+    smpl->n_sample++;
+
+    return res;
+}
+
+llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+    time_meas tm(smpl->t_sample_us);
+
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
+    const auto & params = smpl->params;
+
+    const float temp = params.temp;
+    const int mirostat = params.mirostat;
+
+    auto & cur_p = candidates;
+
+    llama_token res = 0;
+
+    if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
+        // greedy sampling, with probs
+        llama_sampling_softmax_impl(cur_p);
+        res = cur_p->data[0].id;
+    } else if (temp == 0.0f) {
+        // greedy sampling, no probs
+        res = llama_sampling_sample_greedy(smpl, cur_p);
+    } else {
+        if (mirostat != 0) {
+            llama_sampling_temp(smpl, cur_p);
+            res = llama_sampling_sample_mirostat(smpl, cur_p);
+        } else {
+            for (const auto & sampler : smpl->samplers) {
+                switch (sampler) {
+                    case LLAMA_SAMPLER_TYPE_TOP_K:       llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break;
+                    case LLAMA_SAMPLER_TYPE_TFS_Z:       llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break;
+                    case LLAMA_SAMPLER_TYPE_TYPICAL_P:   llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break;
+                    case LLAMA_SAMPLER_TYPE_TOP_P:       llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break;
+                    case LLAMA_SAMPLER_TYPE_MIN_P:       llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break;
+                    case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break;
+                    default : break;
+                }
+            }
+
+            res = llama_sampling_sample_dist(smpl, cur_p);
+
+            //{
+            //    const int n_top = 10;
+            //    LOG("top %d candidates:\n", n_top);
+
+            //    for (int i = 0; i < n_top; i++) {
+            //        const llama_token id = cur_p.data[i].id;
+            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
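// [editor's note — not part of this patch] The for/switch above applies the
// samplers in the order stored in smpl->samplers; callers choose that order
// through llama_sampling_params before creating the sampler. A hedged sketch,
// assuming `samplers`/`n_samplers` are the fixed-size array and count implied
// by llama_sampling_default_params() earlier in this diff:
//
//     auto sparams = llama_sampling_default_params();
//     sparams.samplers[0] = LLAMA_SAMPLER_TYPE_TOP_K;
//     sparams.samplers[1] = LLAMA_SAMPLER_TYPE_MIN_P;
//     sparams.samplers[2] = LLAMA_SAMPLER_TYPE_TEMPERATURE;
//     sparams.n_samplers  = 3;
//     struct llama_sampling * smpl = llama_sampling_init(model, sparams);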
+ // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); + // } + //} + + //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); + } + } + + smpl->n_sample++; + + return res; +} + +void llama_sampling_accept( + struct llama_sampling * smpl, + llama_token token, + bool apply_grammar) { + time_meas tm(smpl->t_accept_us); + + llama_sampling_accept_impl(*smpl, token, apply_grammar); + + smpl->n_accept++; +} + +int llama_sampling_n_prev(const struct llama_sampling * smpl) { + return llama_sampling_n_prev_impl(*smpl); +} + +llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) { + return llama_sampling_prev_impl(*smpl, ith); +} + +llama_token llama_sampling_last(const struct llama_sampling * smpl) { + return llama_sampling_prev_impl(*smpl, 0); +} + +// +// model split +// + int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { @@ -20707,30 +20961,32 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int return 0; } -struct llama_timings llama_get_timings(struct llama_context * ctx) { - struct llama_timings result = { - /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, - /*.t_end_ms =*/ 1.00 * ggml_time_ms(), - /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, - /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us, - /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, - /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, +void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl) { + const llama_timings timings = { + /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, + /*.t_end_ms =*/ 1.00 * ggml_time_ms(), + /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, + /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0), + /*.t_grammar_ms =*/ 1e-3 * (smpl ? smpl->t_grammar_us : 0.0), + /*.t_accept_ms =*/ 1e-3 * (smpl ? smpl->t_accept_us : 0.0), + /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, + /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, - /*.n_sample =*/ std::max(1, ctx->sampling.n_sample), - /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), - /*.n_eval =*/ std::max(1, ctx->n_eval), + /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : 0), + /*.n_grammar =*/ std::max(0, smpl ? smpl->n_grammar : 0), + /*.n_accept =*/ std::max(0, smpl ? 
smpl->n_accept : 0), + /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), + /*.n_eval =*/ std::max(1, ctx->n_eval), }; - return result; -} - -void llama_print_timings(struct llama_context * ctx) { - const llama_timings timings = llama_get_timings(ctx); - LLAMA_LOG_INFO("\n"); LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); - LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample); + LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling); + LLAMA_LOG_INFO("%s: grammar time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_grammar_ms, timings.n_grammar, timings.t_grammar_ms / timings.n_grammar, 1e3 / timings.t_grammar_ms * timings.n_grammar); + //LLAMA_LOG_INFO("%s: accept time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + // __func__, timings.t_accept_ms, timings.n_accept, timings.t_accept_ms / timings.n_accept, 1e3 / timings.t_accept_ms * timings.n_accept); LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", @@ -20738,12 +20994,16 @@ void llama_print_timings(struct llama_context * ctx) { LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); } -void llama_reset_timings(struct llama_context * ctx) { +void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl) { ctx->t_start_us = ggml_time_us(); ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; - ctx->sampling.reset_timings(); + if (smpl) { + smpl->t_sample_us = smpl->n_sample = 0; + smpl->t_grammar_us = smpl->n_grammar = 0; + smpl->t_accept_us = smpl->n_accept = 0; + } } const char * llama_print_system_info(void) { @@ -20785,21 +21045,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { 1.0e-3 * ctx->t_eval_us / ctx->n_eval); fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); - fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", - 1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample); fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); - fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample); fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us); fprintf(stream, 
"ts_eval: %.2f # tokens / second during generation\n", 1.0e6 * ctx->n_eval / ctx->t_eval_us); fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); - fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", - 1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us); } // For internal test use diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 9c4e7d18e..788b02a6a 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -2,33 +2,18 @@ #undef NDEBUG #endif -#define LLAMA_API_INTERNAL - -#include "ggml.h" -#include "llama.h" -#include "grammar-parser.h" -#include "json-schema-to-grammar.h" #include "unicode.h" +#include "llama-grammar.h" +#include "json-schema-to-grammar.h" + #include #include #include using json = nlohmann::ordered_json; -static llama_grammar* build_grammar(const std::string & grammar_str) { - auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); - - // Ensure we parsed correctly - assert(!parsed_grammar.rules.empty()); - - // Ensure we have a root node - assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); - - std::vector grammar_rules(parsed_grammar.c_rules()); - llama_grammar* grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); - - return grammar; +static llama_grammar * build_grammar(const std::string & grammar_str) { + return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); } static bool test_build_grammar_fails(const std::string & grammar_str) { @@ -45,17 +30,15 @@ static bool test_build_grammar_fails(const std::string & grammar_str) { } static bool match_string(const std::string & input, llama_grammar * grammar) { - auto decoded = decode_utf8(input, {}); - - const auto & code_points = decoded.first; + const auto cpts = unicode_cpts_from_utf8(input); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + for (const auto & cpt : cpts) { const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy - llama_grammar_accept(rules, prev_stacks, *it, cur_stacks); + cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); if (cur_stacks.empty()) { // no stacks means that the grammar failed to match at this point @@ -77,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str()); fflush(stderr); - auto grammar = build_grammar(grammar_str); + auto * grammar = build_grammar(grammar_str); // Save the original grammar stacks so that we can reset after every new string we want to test const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar); @@ -143,7 +126,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, } // Clean up allocated memory - llama_grammar_free(grammar); + llama_grammar_free_impl(grammar); } static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector & passing_strings, const std::vector & failing_strings) { test(test_desc + ". 
Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings); @@ -683,7 +666,8 @@ static void test_failure_missing_root() { term ::= number number ::= [0-9]+)"""; - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + llama_grammar_parser parsed_grammar; + parsed_grammar.parse(grammar_str.c_str()); // Ensure we parsed correctly assert(!parsed_grammar.rules.empty()); @@ -705,7 +689,8 @@ static void test_failure_missing_reference() { fprintf(stderr, " Expected error: "); - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + llama_grammar_parser parsed_grammar; + parsed_grammar.parse(grammar_str.c_str()); // Ensure we did NOT parsed correctly assert(parsed_grammar.rules.empty()); diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 5df5abb25..259172d99 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -3,7 +3,7 @@ #endif #include "llama.h" -#include "grammar-parser.h" +#include "llama-grammar.h" #include @@ -22,7 +22,8 @@ static const char * type_str(llama_gretype type) { static void verify_parsing(const char *grammar_bytes, const std::vector> expected, const std::vector &expected_rules) { uint32_t index = 0; - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes); + llama_grammar_parser parsed_grammar; + parsed_grammar.parse(grammar_bytes); std::map symbol_names; for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { @@ -129,9 +130,10 @@ static void verify_parsing(const char *grammar_bytes, const std::vector #include #include #include -#include "json-schema-to-grammar.h" -#include "grammar-parser.h" - static std::string trim(const std::string & source) { std::string s(source); s.erase(0,s.find_first_not_of(" \n\r\t")); @@ -40,7 +41,8 @@ struct TestCase { } void verify_expectation_parseable() const { try { - auto state = grammar_parser::parse(expected_grammar.c_str()); + llama_grammar_parser state; + state.parse(expected_grammar.c_str()); if (state.symbol_ids.find("root") == state.symbol_ids.end()) { throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar); } diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 1f3a267b3..6f1374ca8 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -2,16 +2,15 @@ #undef NDEBUG #endif -#define LLAMA_API_INTERNAL #include "llama.h" -#include "grammar-parser.h" +#include "llama-grammar.h" #include #include int main() { - grammar_parser::parse_state parsed_grammar; + llama_grammar_parser parsed_grammar; std::vector> expected = { {"expr", 2}, @@ -117,7 +116,7 @@ int main() llama_grammar * grammar = NULL; std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); @@ -174,13 +173,13 @@ int main() }}; auto index = 0; - for (auto stack : llama_grammar_get_stacks(grammar)) + for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar)) { // compare stack to expected_stack for (uint32_t i = 0; i < stack.size(); i++) { - auto element = stack[i]; - auto expected_element = expected_stacks[index][i]; + const llama_grammar_element * element = stack[i]; + 
const llama_grammar_element & expected_element = expected_stacks[index][i]; // pretty print error message before asserting if (expected_element.type != element->type || expected_element.value != element->value) @@ -403,6 +402,8 @@ int main() delete[] candidate.code_points; candidate.code_points = nullptr; } - llama_grammar_free(grammar); + + llama_grammar_free_impl(grammar); + return 0; } diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 6c2a5db9a..f5e32a741 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -1,5 +1,6 @@ #include "ggml.h" #include "llama.h" +#include "llama-sampling.h" #ifdef NDEBUG #undef NDEBUG @@ -20,6 +21,7 @@ static void dump(const llama_token_data_array * candidates) { static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) { const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -28,9 +30,9 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -49,9 +52,9 @@ static void test_top_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float z) { const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -71,7 +75,7 @@ static void test_tfs(const std::vector & probs, const std::vector llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; DUMP(&candidates_p); - llama_sample_tail_free(nullptr, &candidates_p, z, 1); + llama_sampling_tail_free_impl(&candidates_p, z, 1); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -82,6 +86,7 @@ static void test_tfs(const std::vector & probs, const std::vector static void test_min_p(const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -91,9 +96,9 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -112,7 +118,7 @@ static void test_typical(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & last_tokens, const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence ) { GGML_ASSERT(probs.size() == expected_probs.size()); const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -135,11 +142,16 @@ static void test_repetition_penalties( candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); } + llama_token_cnt token_count; + for (size_t i = 0; i < last_tokens.size(); i++) { + token_count[last_tokens[i]]++; + } + llama_token_data_array candidates_p = { candidates.data(), 
candidates.size(), false }; - llama_sample_softmax(nullptr, &candidates_p); + llama_sampling_softmax_impl(&candidates_p); DUMP(&candidates_p); - llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); - llama_sample_softmax(nullptr, &candidates_p); + llama_sampling_penalties_impl(&candidates_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); + llama_sampling_softmax_impl(&candidates_p); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -148,8 +160,7 @@ static void test_repetition_penalties( } } -static void test_sampler_queue( - const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p +static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { std::vector candidates; candidates.reserve(n_vocab); @@ -165,16 +176,16 @@ static void test_sampler_queue( for (auto s : samplers_sequence) { switch (s){ - case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break; + case 'k': llama_sampling_top_k_impl(&candidates_p, top_k, 1); break; case 'f': GGML_ABORT("tail_free test not implemented"); case 'y': GGML_ABORT("typical test not implemented"); - case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break; - case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break; + case 'p': llama_sampling_top_p_impl(&candidates_p, top_p, 1); break; + case 'm': llama_sampling_min_p_impl(&candidates_p, min_p, 1); break; case 't': GGML_ABORT("temperature test not implemented"); default : GGML_ABORT("Unknown sampler"); } - llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests + llama_sampling_softmax_impl(&candidates_p); // make sure tokens are sorted for tests const int size = candidates_p.size; @@ -259,13 +270,13 @@ int main(void) { test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 
0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
     test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
     test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
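// ---------------------------------------------------------------------------
// [editor's note — illustrative sketch, not part of the patch above]
// This change moves sampling state out of llama_context and into a standalone
// llama_sampling object. A minimal sketch of how a caller might drive the new
// API (llama_sampling_default_params / _init / _set_logits / _sample /
// _accept / _free and the two-argument llama_print_timings), assuming an
// already created `llama_model * model` and `llama_context * ctx` whose
// prompt has been tokenized and decoded elsewhere:
// ---------------------------------------------------------------------------

#include "llama.h"

#include <vector>

static std::vector<llama_token> generate_n_tokens(struct llama_context * ctx,
                                                  const struct llama_model * model,
                                                  int n_predict) {
    // sampling parameters and state are now separate from the context
    struct llama_sampling_params sparams = llama_sampling_default_params();
    sparams.top_k = 40;
    sparams.top_p = 0.95f;
    sparams.temp  = 0.8f;

    struct llama_sampling * smpl = llama_sampling_init(model, sparams);

    std::vector<llama_token> result;

    for (int i = 0; i < n_predict; ++i) {
        // feed the logits of the last decoded position into the sampler
        // (-1 selects the last set of logits computed by the context)
        llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, -1));

        // run the configured sampler chain; nullptr means "use the internal
        // candidate array filled by llama_sampling_set_logits"
        const llama_token id = llama_sampling_sample(smpl, nullptr);

        // record the token in the sampler history (and grammar, if one is set)
        llama_sampling_accept(smpl, id, /*apply_grammar=*/true);

        result.push_back(id);

        // decode `id` with llama_decode() here before the next iteration
        // (omitted for brevity)
    }

    // the timing report now takes the sampling state explicitly
    llama_print_timings(ctx, smpl);

    llama_sampling_free(smpl);

    return result;
}

// One consequence of this split: multiple llama_sampling instances can be
// created independently of a single llama_context (e.g. one per sequence),
// which the old ctx->sampling member made awkward.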