diff --git a/llama.cpp b/llama.cpp index 581a8399d..978327a5b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k( } } - if (top_k > 0 && top_k < n_logits) { - sample_top_k(logits_id, top_k); - } - - float maxl = -std::numeric_limits::infinity(); - for (const auto & kv : logits_id) { - maxl = Max(maxl, kv.first); - } + sample_top_k(logits_id, top_k > 0 ? Min(top_k, n_logits) : n_logits); // compute probs for the top k tokens std::vector probs; probs.reserve(logits_id.size()); + float maxl = logits_id[0].first; double sum = 0.0; for (const auto & kv : logits_id) { const float p = expf(kv.first - maxl); @@ -1271,16 +1265,11 @@ static llama_vocab::id llama_sample_top_p_top_k( break; } } - - cumsum = 1.0/cumsum; - for (int i = 0; i < (int) probs.size(); i++) { - probs[i] *= cumsum; - } } //printf("\n"); //for (int i = 0; i < (int) 10; i++) { - // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); + // printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]); //} //printf("\n\n"); //exit(0);