From 238657db2364cfb728c694470a4a81702afea760 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Tue, 31 Oct 2023 14:44:49 -0500 Subject: [PATCH] samplers : Min-P sampler implementation [alternative to Top P/Top K] (#3841) * Introduce the new Min-P sampler by @kalomaze The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. * Min-P enabled and set to 0.05 default --------- Co-authored-by: Georgi Gerganov Co-authored-by: cebtenzzre --- common/common.cpp | 8 ++++++++ common/sampling.cpp | 6 ++++-- common/sampling.h | 1 + examples/main/README.md | 8 ++++++++ llama.cpp | 26 ++++++++++++++++++++++++++ llama.h | 7 +++++++ 6 files changed, 54 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index c187128d6..dc4865e80 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -218,6 +218,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } sparams.top_p = std::stof(argv[i]); + } else if (arg == "--min-p") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.min_p = std::stof(argv[i]); } else if (arg == "--temp") { if (++i >= argc) { invalid_param = true; @@ -679,6 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); + printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); @@ -1275,6 +1282,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); + fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); } diff --git a/common/sampling.cpp b/common/sampling.cpp index c4996c985..673d67a6d 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -89,10 +89,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" - "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n" + "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, - params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, + params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); return std::string(result); @@ -110,6 +110,7 @@ llama_token llama_sampling_sample( const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; const float top_p = params.top_p; + const float min_p = params.min_p; const float tfs_z = params.tfs_z; const float typical_p = params.typical_p; const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; @@ -190,6 +191,7 @@ llama_token llama_sampling_sample( llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); + llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); llama_sample_temp (ctx_main, &cur_p, temp); id = llama_sample_token(ctx_main, &cur_p); diff --git a/common/sampling.h b/common/sampling.h index 62ea6d4cf..7c9b8dcf2 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -14,6 +14,7 @@ typedef struct llama_sampling_params { int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled diff --git a/examples/main/README.md b/examples/main/README.md index a9561c383..a3428b487 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -208,6 +208,14 @@ Top-p sampling, also known as nucleus sampling, is another text generation metho Example usage: `--top-p 0.95` +### Min P Sampling + +- `--min-p N`: Sets a minimum base probability threshold for token selection (default: 0.05). + +The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out. + +Example usage: `--min-p 0.05` + ### Tail Free Sampling (TFS) - `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled). diff --git a/llama.cpp b/llama.cpp index e599917a8..7ee589298 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7368,6 +7368,32 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can } } +void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + if (p <= 0.0f || !candidates->size) { + return; + } + + llama_sample_softmax(ctx, candidates); + + const int64_t t_start_sample_us = ggml_time_us(); + + float scale = candidates->data[0].p; // scale by max prob + size_t i = 1; // first token always matches + + for (; i < candidates->size; ++i) { + if (candidates->data[i].p < p * scale && i >= min_keep) { + break; // prob too small + } + } + + // Resize the output vector to keep only the matching tokens + candidates->size = i; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; diff --git a/llama.h b/llama.h index d727dbd9f..75fe391ef 100644 --- a/llama.h +++ b/llama.h @@ -598,6 +598,13 @@ extern "C" { float p, size_t min_keep); + /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + LLAMA_API void llama_sample_min_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); + /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. LLAMA_API void llama_sample_tail_free( struct llama_context * ctx,