diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index dea694165..dbe445391 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1991,8 +1991,12 @@ int main(int argc, char ** argv) { params.n_batch = std::min(params.n_batch, n_kv); } else { params.n_batch = std::min(params.n_batch, params.n_ctx); - // ensure there's at least enough seq_ids for HellaSwag - params.n_parallel = std::max(4, params.n_parallel); + if (params.kl_divergence) { + params.n_parallel = 1; + } else { + // ensure there's at least enough seq_ids for HellaSwag + params.n_parallel = std::max(4, params.n_parallel); + } } if (params.ppl_stride > 0) {