diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index df902fb1c..292502f87 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -325,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par double nll = 0.0; double nll2 = 0.0; + const int num_batches = (n_ctx + n_batch - 1) / n_batch; + + std::vector logits; + if (num_batches > 1) { + logits.reserve((size_t)n_ctx * n_vocab); + } + fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); std::vector workers(std::thread::hardware_concurrency() - 1); @@ -333,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const int start = i * n_ctx; const int end = start + n_ctx; - const int num_batches = (n_ctx + n_batch - 1) / n_batch; - - std::vector logits; - const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache @@ -362,8 +365,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // restore the original token in case it was set to BOS tokens[batch_start] = token_org; - const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + if (num_batches > 1) { + const auto * batch_logits = llama_get_logits(ctx); + logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + } } const auto t_end = std::chrono::high_resolution_clock::now(); @@ -392,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // last 256 tokens. Then, we split the input up into context window size chunks to // process the entire prompt. const int first = n_ctx/2; - process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, + const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); + process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); count += n_ctx - first - 1; @@ -406,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); } fflush(stdout); + + logits.clear(); } printf("\n");