mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-10-30 06:30:15 +01:00
winogrande: evaluate log-probs in parallel (#5036)
This is a relatively minor performance tweak resulting in ~10% speedup on my system. Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
2b3b999cac
commit
7051aacfac
@ -458,7 +458,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
|
||||
return true;
|
||||
}
|
||||
|
||||
static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
|
||||
static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
|
||||
const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
|
||||
constexpr int k_token_chunk = 4;
|
||||
if (eval_results.size() != eval_pairs.size()) {
|
||||
@ -700,7 +700,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||
}
|
||||
}
|
||||
// Then we do the actual calculation
|
||||
hellaswag_compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
|
||||
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
|
||||
|
||||
size_t ir = 0;
|
||||
|
||||
@ -906,6 +906,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
||||
std::vector<float> tok_logits(n_vocab);
|
||||
std::vector<float> batch_logits(n_vocab*n_ctx);
|
||||
|
||||
std::vector<std::pair<size_t, llama_token>> eval_pairs;
|
||||
std::vector<float> eval_results;
|
||||
std::vector<std::thread> workers(std::thread::hardware_concurrency());
|
||||
|
||||
int n_correct = 0;
|
||||
int n_done = 0;
|
||||
|
||||
@ -956,6 +960,30 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
||||
return;
|
||||
}
|
||||
|
||||
eval_pairs.clear();
|
||||
for (size_t i = i0; i < i1; ++i) {
|
||||
auto & task = data[i];
|
||||
|
||||
const bool skip_choice =
|
||||
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
|
||||
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
|
||||
|
||||
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
|
||||
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
|
||||
size_t li = n_base1 - 1;
|
||||
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
|
||||
eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[0][j+1]));
|
||||
}
|
||||
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
|
||||
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
|
||||
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
|
||||
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
|
||||
eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[1][j+1]));
|
||||
}
|
||||
}
|
||||
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
|
||||
|
||||
size_t ir = 0;
|
||||
for (size_t i = i0; i < i1; ++i) {
|
||||
auto & task = data[i];
|
||||
|
||||
@ -964,54 +992,21 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
||||
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
|
||||
|
||||
float score_1st = 0;
|
||||
bool is_nan_1st = false;
|
||||
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
|
||||
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
|
||||
size_t li = n_base1 - 1;
|
||||
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
|
||||
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
|
||||
const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]];
|
||||
if (std::isnan(prob) || !prob) {
|
||||
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
|
||||
prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1);
|
||||
is_nan_1st = true;
|
||||
break;
|
||||
}
|
||||
score_1st += std::log(prob);
|
||||
score_1st += eval_results[ir++];
|
||||
}
|
||||
score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
|
||||
|
||||
float score_2nd = 0;
|
||||
bool is_nan_2nd = false;
|
||||
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
|
||||
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
|
||||
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
|
||||
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
|
||||
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
|
||||
const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]];
|
||||
if (std::isnan(prob) || !prob) {
|
||||
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
|
||||
prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2);
|
||||
is_nan_2nd = true;
|
||||
break;
|
||||
}
|
||||
score_2nd += std::log(prob);
|
||||
score_2nd += eval_results[ir++];
|
||||
}
|
||||
score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
|
||||
|
||||
if (is_nan_1st || is_nan_2nd) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (std::isnan(score_1st) || std::isnan(score_2nd)) {
|
||||
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
|
||||
printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size());
|
||||
printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size());
|
||||
printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix);
|
||||
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice);
|
||||
continue;
|
||||
}
|
||||
|
||||
int result = score_1st > score_2nd ? 1 : 2;
|
||||
|
||||
if (result == task.answer) {
|
||||
@ -1019,7 +1014,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
||||
}
|
||||
++n_done;
|
||||
|
||||
// Print the accumulated accuracy mean x 100
|
||||
// print the accumulated accuracy mean x 100
|
||||
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
|
||||
fflush(stdout);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user