llama : always sort logits before nucleus sampling (#812)

* Always sort logits before nucleus sampling

* remove second normalization

- fix windows build
- remove normalization since std::discrete_distribution does not require it
This commit is contained in:
Ivan Stepanov 2023-04-07 19:02:12 +03:00 committed by GitHub
parent cc9cee8e9e
commit 4953e9007f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
} }
} }
if (top_k > 0 && top_k < n_logits) { sample_top_k(logits_id, top_k > 0 ? Min(top_k, n_logits) : n_logits);
sample_top_k(logits_id, top_k);
}
float maxl = -std::numeric_limits<float>::infinity();
for (const auto & kv : logits_id) {
maxl = Max(maxl, kv.first);
}
// compute probs for the top k tokens // compute probs for the top k tokens
std::vector<float> probs; std::vector<float> probs;
probs.reserve(logits_id.size()); probs.reserve(logits_id.size());
float maxl = logits_id[0].first;
double sum = 0.0; double sum = 0.0;
for (const auto & kv : logits_id) { for (const auto & kv : logits_id) {
const float p = expf(kv.first - maxl); const float p = expf(kv.first - maxl);
@ -1271,16 +1265,11 @@ static llama_vocab::id llama_sample_top_p_top_k(
break; break;
} }
} }
cumsum = 1.0/cumsum;
for (int i = 0; i < (int) probs.size(); i++) {
probs[i] *= cumsum;
}
} }
//printf("\n"); //printf("\n");
//for (int i = 0; i < (int) 10; i++) { //for (int i = 0; i < (int) 10; i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); // printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
//} //}
//printf("\n\n"); //printf("\n\n");
//exit(0); //exit(0);