From 4953e9007f86327aabc8312a7211c18019a3a40e Mon Sep 17 00:00:00 2001 From: Ivan Stepanov Date: Fri, 7 Apr 2023 19:02:12 +0300 Subject: [PATCH] llama : always sort logits before nucleus sampling (#812) * Always sort logits before nucleus sampling * remove second normalization - fix windows build - remove normalization since std::discrete_distribution does not require it --- llama.cpp | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/llama.cpp b/llama.cpp index 581a8399d..978327a5b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k( } } - if (top_k > 0 && top_k < n_logits) { - sample_top_k(logits_id, top_k); - } - - float maxl = -std::numeric_limits::infinity(); - for (const auto & kv : logits_id) { - maxl = Max(maxl, kv.first); - } + sample_top_k(logits_id, top_k > 0 ? Min(top_k, n_logits) : n_logits); // compute probs for the top k tokens std::vector probs; probs.reserve(logits_id.size()); + float maxl = logits_id[0].first; double sum = 0.0; for (const auto & kv : logits_id) { const float p = expf(kv.first - maxl); @@ -1271,16 +1265,11 @@ static llama_vocab::id llama_sample_top_p_top_k( break; } } - - cumsum = 1.0/cumsum; - for (int i = 0; i < (int) probs.size(); i++) { - probs[i] *= cumsum; - } } //printf("\n"); //for (int i = 0; i < (int) 10; i++) { - // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); + // printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]); //} //printf("\n\n"); //exit(0);