From b0f27361f3539a81d983a8b045f3c61e682d9fc0 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 24 Sep 2024 09:03:17 +0300
Subject: [PATCH] sampling : avoid expensive softmax during greedy sampling
 (#9605)

* sampling : avoid expensive softmax during greedy sampling

ggml-ci

* speculative : fix default RNG seed + set sparams.n_probs

* Update tests/test-sampling.cpp

Co-authored-by: slaren

* sampling : add clarifying comment [no ci]

---------

Co-authored-by: slaren
---
 common/sampling.cpp                  | 10 ++++++-
 examples/speculative/speculative.cpp |  5 +++-
 include/llama.h                      |  1 +
 src/llama-sampling.cpp               |  7 +++--
 tests/test-sampling.cpp              | 42 +++++++++++++++++++++++++++-
 5 files changed, 59 insertions(+), 6 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index e51d07611..3dc7f1120 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -209,7 +209,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
             GGML_ASSERT(false && "unknown mirostat version");
         }
     } else {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+        if (params.n_probs > 0) {
+            // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
+            // ref: https://github.com/ggerganov/llama.cpp/pull/9605
+            //
+            // the following will not produce exactly the same probs as applying softmax to the full vocabulary, but
+            // it is much faster, since we avoid sorting all tokens and should give a good approximation
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+        }
         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
     }
 
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index fbac21811..adf6255e1 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -32,6 +32,9 @@ struct seq_draft {
 int main(int argc, char ** argv) {
     gpt_params params;
 
+    // needed to get candidate probs even for temp <= 0.0
+    params.sparams.n_probs = 128;
+
     if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
     }
@@ -49,7 +52,7 @@ int main(int argc, char ** argv) {
     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
     const float p_split = params.p_split;
 
-    std::default_random_engine rng(params.sparams.seed);
+    std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
     std::uniform_real_distribution<> u_dist;
 
     // init llama.cpp
diff --git a/include/llama.h b/include/llama.h
index f316a87ba..132937a07 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1066,6 +1066,7 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
+    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 5299f5116..e255a8fc4 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -3,13 +3,14 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
-#include
 #include
-#include
-#include
 #include
 #include
 #include
+#include
+#include
+#include
 #include
 #include
 #include
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index d738b7a45..6e021c4c7 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -1,6 +1,5 @@
 #include "ggml.h"
 #include "llama.h"
-#include "llama-sampling.h"
 
 #ifdef NDEBUG
 #undef NDEBUG
@@ -249,6 +248,45 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
                  samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
 }
 
+static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
+    std::vector<llama_token_data> cur(data.size());
+    std::copy(data.begin(), data.end(), cur.begin());
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    llama_sampler_apply(cnstr, &cur_p);
+    llama_sampler_reset(cnstr);
+    const int64_t t_start = ggml_time_us();
+    for (int i = 0; i < n_iter; i++) {
+        std::copy(data.begin(), data.end(), cur.begin());
+        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+        llama_sampler_apply(cnstr, &cur_p);
+        llama_sampler_reset(cnstr);
+    }
+    const int64_t t_end = ggml_time_us();
+    llama_sampler_free(cnstr);
+    printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
+}
+
+#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
+
+static void test_perf() {
+    const int n_vocab = 1 << 17;
+
+    std::vector<llama_token_data> data;
+
+    data.reserve(n_vocab);
+    for (int i = 0; i < n_vocab; i++) {
+        const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f);
+        data.emplace_back(llama_token_data{i, logit, 0.0f});
+    }
+
+    BENCH(llama_sampler_init_top_k    (40),      data, 32);
+    BENCH(llama_sampler_init_top_p    (0.8f, 1), data, 32);
+    BENCH(llama_sampler_init_min_p    (0.2f, 1), data, 32);
+    BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
+    BENCH(llama_sampler_init_typical  (0.5f, 1), data, 32);
+    BENCH(llama_sampler_init_softmax  (),        data, 32);
+}
+
 int main(void) {
     ggml_time_init();
 
@@ -316,5 +354,7 @@ int main(void) {
 
     printf("OK\n");
 
+    test_perf();
+
     return 0;
 }
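
Note: for reference, the greedy chain that gpt_sampler_init now builds can also be assembled directly through the C API declared in include/llama.h. The following is a minimal sketch and not part of the patch; the helper name make_greedy_chain_with_probs and the parameter n_probs_hint are illustrative only, and it assumes the llama_sampler chain API present at this revision:

    // illustrative sketch only (not part of the patch): greedy sampling that still
    // exposes approximate probabilities for the top candidates, mirroring the change
    // to gpt_sampler_init above
    #include "llama.h"

    static struct llama_sampler * make_greedy_chain_with_probs(int32_t n_probs_hint) {
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
        if (n_probs_hint > 0) {
            // truncate to the top candidates first so the softmax only has to sort a small set
            llama_sampler_chain_add(chain, llama_sampler_init_top_k(n_probs_hint));
            llama_sampler_chain_add(chain, llama_sampler_init_softmax());
        }
        // argmax selection is unaffected by the truncation above, since top-k always keeps the highest-logit token
        llama_sampler_chain_add(chain, llama_sampler_init_greedy());
        return chain;
    }

A caller would then pick tokens with llama_sampler_sample(chain, ctx, -1) and release the chain with llama_sampler_free(chain), as elsewhere in the examples.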