diff --git a/common/sampling.cpp b/common/sampling.cpp index e8675a8c0..844ad7c53 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -132,7 +132,7 @@ static void sampler_queue( const float temp = params.temp; const float dynatemp_range = params.dynatemp_range; const float dynatemp_exponent = params.dynatemp_exponent; - const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; + const int32_t top_k = params.top_k; const float top_p = params.top_p; const float min_p = params.min_p; const float tfs_z = params.tfs_z; diff --git a/llama.cpp b/llama.cpp index c45ae1d50..f8f5796a4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8585,6 +8585,10 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can // } const int64_t t_start_sample_us = ggml_time_us(); + + if (k <= 0) { + k = candidates->size; + } k = std::max(k, (int) min_keep); k = std::min(k, (int) candidates->size); diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index c3b3d6629..6374958fe 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -235,6 +235,8 @@ int main(void) { test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3); + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4); + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0); test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0); test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);