diff --git a/common/common.cpp b/common/common.cpp index 059d7a76a..7c3e11875 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -905,7 +905,7 @@ llama_token llama_sample_token( llama_token id = 0; - float * logits = llama_get_logits(ctx) + idx * n_vocab; + float * logits = llama_get_logits_ith(ctx, idx); // Apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { diff --git a/common/common.h b/common/common.h index ed86aa95b..16e30b2f5 100644 --- a/common/common.h +++ b/common/common.h @@ -183,7 +183,7 @@ std::string llama_detokenize_bpe( // - ctx_guidance: context to use for classifier-free guidance, ignore if NULL // - grammar: grammar to use for sampling, ignore if NULL // - last_tokens: needed for repetition penalty, ignore if empty -// - idx: sample from llama_get_logits(ctx) + idx * n_vocab +// - idx: sample from llama_get_logits_ith(ctx, idx) // // returns: // - token: sampled token diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index cf48ce0c0..08b082b33 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -150,7 +150,7 @@ int main(int argc, char ** argv) { } auto n_vocab = llama_n_vocab(ctx); - auto logits = llama_get_logits(ctx) + i_batch[i] * n_vocab; + auto logits = llama_get_logits_ith(ctx, i_batch[i]); std::vector candidates; candidates.reserve(n_vocab);