From cb1c0727bd59803b439b6a3af121c99e6393ff3d Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Mon, 21 Aug 2023 11:11:31 +0300 Subject: [PATCH] HellaSwag: split token evaluation into batches if needed (#2681) Co-authored-by: Iwan Kawrakow --- examples/perplexity/perplexity.cpp | 39 +++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 682c39b16..2409db69f 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -122,6 +122,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) { printf("\n"); } +std::vector hellaswag_evaluate_tokens(llama_context * ctx, const std::vector& tokens, int n_past, int n_batch, + int n_vocab, int n_thread) { + std::vector result; + result.reserve(tokens.size() * n_vocab); + size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch; + for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { + size_t n_tokens = tokens.size() - i_chunk * n_batch; + n_tokens = std::min(n_tokens, size_t(n_batch)); + if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return {}; + } + + const auto logits = llama_get_logits(ctx); + result.insert(result.end(), logits, logits + n_tokens * n_vocab); + + n_past += n_tokens; + } + return result; +} + void hellaswag_score(llama_context * ctx, const gpt_params & params) { // Calculates hellaswag score (acc_norm) from prompt // @@ -235,15 +256,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) { query_embd.resize(32); } - // Evaluate the query - if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) { + auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads); + if (logits.empty()) { fprintf(stderr, "%s : failed to eval\n", __func__); return; } - auto query_logits = llama_get_logits(ctx); - - std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float)); + std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float)); const auto first_probs = softmax(tok_logits); hs_data[task_idx].ending_logprob_count[0] = 1; @@ -252,7 +271,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) { // Calculate the logprobs over the ending for (size_t j = context_size; j < query_size - 1; j++) { - std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float)); + std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float)); const float prob = softmax(tok_logits)[query_embd[j + 1]]; @@ -271,7 +290,6 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) { // Tokenize the query query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false); query_size = query_embd.size(); - //printf("Second query: %d\n",(int)query_size); // Stop if query wont fit the ctx window if (context_size + query_size > (size_t)params.n_ctx) { @@ -286,19 +304,18 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) { //} // Evaluate the query - if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) { + logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads); + if (logits.empty()) { fprintf(stderr, "%s : failed to eval\n", __func__); return; } - query_logits = llama_get_logits(ctx); - hs_data[task_idx].ending_logprob_count[ending_idx] = 1; hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]); // Calculate the logprobs over the ending for (size_t j = 0; j < query_size - 1; j++) { - std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float)); + std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float)); const float prob = softmax(tok_logits)[query_embd[j + 1]];