mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 21:10:24 +01:00
HellaSwag: split token evaluation into batches if needed (#2681)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
9e232f0234
commit
cb1c0727bd
@ -122,6 +122,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
|||||||
printf("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
|
||||||
|
int n_vocab, int n_thread) {
|
||||||
|
std::vector<float> result;
|
||||||
|
result.reserve(tokens.size() * n_vocab);
|
||||||
|
size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
|
||||||
|
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
|
||||||
|
size_t n_tokens = tokens.size() - i_chunk * n_batch;
|
||||||
|
n_tokens = std::min(n_tokens, size_t(n_batch));
|
||||||
|
if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto logits = llama_get_logits(ctx);
|
||||||
|
result.insert(result.end(), logits, logits + n_tokens * n_vocab);
|
||||||
|
|
||||||
|
n_past += n_tokens;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||||
// Calculates hellaswag score (acc_norm) from prompt
|
// Calculates hellaswag score (acc_norm) from prompt
|
||||||
//
|
//
|
||||||
@ -235,15 +256,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
query_embd.resize(32);
|
query_embd.resize(32);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evaluate the query
|
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
|
||||||
if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
|
if (logits.empty()) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto query_logits = llama_get_logits(ctx);
|
std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
|
||||||
|
|
||||||
std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float));
|
|
||||||
const auto first_probs = softmax(tok_logits);
|
const auto first_probs = softmax(tok_logits);
|
||||||
|
|
||||||
hs_data[task_idx].ending_logprob_count[0] = 1;
|
hs_data[task_idx].ending_logprob_count[0] = 1;
|
||||||
@ -252,7 +271,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
// Calculate the logprobs over the ending
|
// Calculate the logprobs over the ending
|
||||||
for (size_t j = context_size; j < query_size - 1; j++) {
|
for (size_t j = context_size; j < query_size - 1; j++) {
|
||||||
|
|
||||||
std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
|
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
|
||||||
|
|
||||||
const float prob = softmax(tok_logits)[query_embd[j + 1]];
|
const float prob = softmax(tok_logits)[query_embd[j + 1]];
|
||||||
|
|
||||||
@ -271,7 +290,6 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
// Tokenize the query
|
// Tokenize the query
|
||||||
query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
|
query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
|
||||||
query_size = query_embd.size();
|
query_size = query_embd.size();
|
||||||
//printf("Second query: %d\n",(int)query_size);
|
|
||||||
|
|
||||||
// Stop if query wont fit the ctx window
|
// Stop if query wont fit the ctx window
|
||||||
if (context_size + query_size > (size_t)params.n_ctx) {
|
if (context_size + query_size > (size_t)params.n_ctx) {
|
||||||
@ -286,19 +304,18 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
//}
|
//}
|
||||||
|
|
||||||
// Evaluate the query
|
// Evaluate the query
|
||||||
if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) {
|
logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
|
||||||
|
if (logits.empty()) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
query_logits = llama_get_logits(ctx);
|
|
||||||
|
|
||||||
hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
|
hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
|
||||||
hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
|
hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
|
||||||
|
|
||||||
// Calculate the logprobs over the ending
|
// Calculate the logprobs over the ending
|
||||||
for (size_t j = 0; j < query_size - 1; j++) {
|
for (size_t j = 0; j < query_size - 1; j++) {
|
||||||
std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
|
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
|
||||||
|
|
||||||
const float prob = softmax(tok_logits)[query_embd[j + 1]];
|
const float prob = softmax(tok_logits)[query_embd[j + 1]];
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user